Parthx10 commited on
Commit
59ac3bd
·
verified ·
1 Parent(s): ec9a389

Upload tapas_demo_1.ipynb

Browse files
Files changed (1) hide show
  1. tapas_demo_1.ipynb +961 -0
tapas_demo_1.ipynb ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "name": "tapas-demo-1.ipynb",
7
+ "provenance": []
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "accelerator": "GPU"
14
+ },
15
+ "cells": [
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {
19
+ "id": "oKB8YaRk05Sl"
20
+ },
21
+ "source": [
22
+ "<a href=\"https://colab.research.google.com/github/google-research/tapas/blob/master/notebooks/sqa_predictions.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "metadata": {
28
+ "id": "-07bRHwv0C7L"
29
+ },
30
+ "source": [
31
+ "##### Copyright 2020 The Google AI Language Team Authors\n",
32
+ "\n",
33
+ "Licensed under the Apache License, Version 2.0 (the \"License\");"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "metadata": {
39
+ "id": "SSpOxRRH0BCU"
40
+ },
41
+ "source": [
42
+ "# Copyright 2019 The Google AI Language Team Authors.\n",
43
+ "#\n",
44
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
45
+ "# you may not use this file except in compliance with the License.\n",
46
+ "# You may obtain a copy of the License at\n",
47
+ "#\n",
48
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
49
+ "#\n",
50
+ "# Unless required by applicable law or agreed to in writing, software\n",
51
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
52
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
53
+ "# See the License for the specific language governing permissions and\n",
54
+ "# limitations under the License."
55
+ ],
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {
62
+ "id": "j5EACclxE7sP"
63
+ },
64
+ "source": [
65
+ "Running a Tapas fine-tuned checkpoint\n",
66
+ "---\n",
67
+ "This notebook shows how to load and make predictions with TAPAS model, which was introduced in the paper: [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349)"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "metadata": {
73
+ "id": "Y-m_JoVCFCV0"
74
+ },
75
+ "source": [
76
+ "# Clone and install the repository\n"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {
82
+ "id": "lF84Z-KayR3Z"
83
+ },
84
+ "source": [
85
+ "First, let's fetch the code from the github repository and install it"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "metadata": {
91
+ "id": "uI6zyIM20Kw4",
92
+ "colab": {
93
+ "base_uri": "https://localhost:8080/"
94
+ },
95
+ "outputId": "88b1f69b-6d78-48d5-d9e6-116f9e58e792"
96
+ },
97
+ "source": [
98
+ "! git clone https://github.com/google-research/tapas.git"
99
+ ],
100
+ "execution_count": 1,
101
+ "outputs": [
102
+ {
103
+ "output_type": "stream",
104
+ "name": "stdout",
105
+ "text": [
106
+ "Cloning into 'tapas'...\n",
107
+ "remote: Enumerating objects: 822, done.\u001b[K\n",
108
+ "remote: Counting objects: 100% (240/240), done.\u001b[K\n",
109
+ "remote: Compressing objects: 100% (139/139), done.\u001b[K\n",
110
+ "remote: Total 822 (delta 119), reused 188 (delta 101), pack-reused 582\u001b[K\n",
111
+ "Receiving objects: 100% (822/822), 861.24 KiB | 3.19 MiB/s, done.\n",
112
+ "Resolving deltas: 100% (472/472), done.\n"
113
+ ]
114
+ }
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "metadata": {
120
+ "id": "PULx_0fmxbOh",
121
+ "colab": {
122
+ "base_uri": "https://localhost:8080/"
123
+ },
124
+ "outputId": "9fe9c21e-9cc6-445d-da84-400ac3979528"
125
+ },
126
+ "source": [
127
+ "! pip install ./tapas"
128
+ ],
129
+ "execution_count": 6,
130
+ "outputs": [
131
+ {
132
+ "output_type": "stream",
133
+ "name": "stdout",
134
+ "text": [
135
+ "Processing ./tapas\n",
136
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
137
+ "Collecting apache-beam[gcp]==2.28.0 (from tapas-table-parsing==0.0.1.dev0)\n",
138
+ " Using cached apache-beam-2.28.0.zip (2.4 MB)\n",
139
+ " \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
140
+ " \n",
141
+ " \u001b[31m×\u001b[0m \u001b[32mpython setup.py egg_info\u001b[0m did not run successfully.\n",
142
+ " \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
143
+ " \u001b[31m╰─>\u001b[0m See above for output.\n",
144
+ " \n",
145
+ " \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
146
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25herror\n",
147
+ "\u001b[1;31merror\u001b[0m: \u001b[1mmetadata-generation-failed\u001b[0m\n",
148
+ "\n",
149
+ "\u001b[31m×\u001b[0m Encountered error while generating package metadata.\n",
150
+ "\u001b[31m╰─>\u001b[0m See above for output.\n",
151
+ "\n",
152
+ "\u001b[1;35mnote\u001b[0m: This is an issue with the package mentioned above, not pip.\n",
153
+ "\u001b[1;36mhint\u001b[0m: See above for details.\n"
154
+ ]
155
+ }
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "markdown",
160
+ "metadata": {
161
+ "id": "7We9ofHuFMuk"
162
+ },
163
+ "source": [
164
+ "# Fetch models fom Google Storage"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "metadata": {
170
+ "id": "sA1jUByqyUNB"
171
+ },
172
+ "source": [
173
+ "Next we can get pretrained checkpoint from Google Storage. For the sake of speed, this is base sized model trained on [SQA](https://www.microsoft.com/en-us/download/details.aspx?id=54253). Note that best results in the paper were obtained with with a large model, with 24 layers instead of 12."
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "metadata": {
179
+ "id": "B10C0Yz6gQyD",
180
+ "colab": {
181
+ "base_uri": "https://localhost:8080/"
182
+ },
183
+ "outputId": "5ae79305-934f-4fe5-ba0a-d0d23d3a4681"
184
+ },
185
+ "source": [
186
+ "! gsutil cp gs://tapas_models/2020_04_21/tapas_sqa_base.zip . && unzip tapas_sqa_base.zip"
187
+ ],
188
+ "execution_count": 3,
189
+ "outputs": [
190
+ {
191
+ "output_type": "stream",
192
+ "name": "stdout",
193
+ "text": [
194
+ "Copying gs://tapas_models/2020_04_21/tapas_sqa_base.zip...\n",
195
+ "/ [0 files][ 0.0 B/ 1.0 GiB] \r==> NOTE: You are downloading one or more large file(s), which would\n",
196
+ "run significantly faster if you enabled sliced object downloads. This\n",
197
+ "feature is enabled by default but requires that compiled crcmod be\n",
198
+ "installed (see \"gsutil help crcmod\").\n",
199
+ "\n",
200
+ "\\ [1 files][ 1.0 GiB/ 1.0 GiB] 86.5 MiB/s \n",
201
+ "Operation completed over 1 objects/1.0 GiB. \n",
202
+ "Archive: tapas_sqa_base.zip\n",
203
+ " creating: tapas_sqa_base/\n",
204
+ " inflating: tapas_sqa_base/model.ckpt.data-00000-of-00001 \n",
205
+ " inflating: tapas_sqa_base/model.ckpt.index \n",
206
+ " inflating: tapas_sqa_base/README.txt \n",
207
+ " inflating: tapas_sqa_base/vocab.txt \n",
208
+ " inflating: tapas_sqa_base/bert_config.json \n",
209
+ " inflating: tapas_sqa_base/model.ckpt.meta \n"
210
+ ]
211
+ }
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "markdown",
216
+ "metadata": {
217
+ "id": "E3107bGlGm7d"
218
+ },
219
+ "source": [
220
+ "# Imports"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "metadata": {
226
+ "id": "pnUjDlLqDd3m"
227
+ },
228
+ "source": [
229
+ "import tensorflow.compat.v1 as tf\n",
230
+ "import os\n",
231
+ "import shutil\n",
232
+ "import csv\n",
233
+ "import pandas as pd\n",
234
+ "import IPython\n",
235
+ "\n",
236
+ "tf.get_logger().setLevel('ERROR')"
237
+ ],
238
+ "execution_count": 4,
239
+ "outputs": []
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "metadata": {
244
+ "id": "aml6oLFl1dSt",
245
+ "colab": {
246
+ "base_uri": "https://localhost:8080/",
247
+ "height": 367
248
+ },
249
+ "outputId": "4bc24340-e13d-47c8-ff7e-4b276f7ab960"
250
+ },
251
+ "source": [
252
+ "from tapas.utils import tf_example_utils\n",
253
+ "from tapas.protos import interaction_pb2\n",
254
+ "from tapas.utils import number_annotation_utils\n",
255
+ "from tapas.scripts import prediction_utils"
256
+ ],
257
+ "execution_count": 5,
258
+ "outputs": [
259
+ {
260
+ "output_type": "error",
261
+ "ename": "ModuleNotFoundError",
262
+ "evalue": "No module named 'tapas.utils'",
263
+ "traceback": [
264
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
265
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
266
+ "\u001b[0;32m<ipython-input-5-c29ce12f712b>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtapas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtf_example_utils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtapas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprotos\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0minteraction_pb2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtapas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumber_annotation_utils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtapas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscripts\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mprediction_utils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
267
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tapas.utils'",
268
+ "",
269
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
270
+ ],
271
+ "errorDetails": {
272
+ "actions": [
273
+ {
274
+ "action": "open_url",
275
+ "actionText": "Open Examples",
276
+ "url": "/notebooks/snippets/importing_libraries.ipynb"
277
+ }
278
+ ]
279
+ }
280
+ }
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "metadata": {
286
+ "id": "AbMUYT1bKMp9"
287
+ },
288
+ "source": [
289
+ "# Load checkpoint for prediction"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "markdown",
294
+ "metadata": {
295
+ "id": "IO0d_wFMy82O"
296
+ },
297
+ "source": [
298
+ "Here's the prediction code, which will create and `interaction_pb2.Interaction` protobuf object, which is the datastructure we use to store examples, and then call the prediction script."
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "metadata": {
304
+ "id": "UKfxspnVFPsc"
305
+ },
306
+ "source": [
307
+ "os.makedirs('results/sqa/tf_examples', exist_ok=True)\n",
308
+ "os.makedirs('results/sqa/model', exist_ok=True)\n",
309
+ "with open('results/sqa/model/checkpoint', 'w') as f:\n",
310
+ " f.write('model_checkpoint_path: \"model.ckpt-0\"')\n",
311
+ "for suffix in ['.data-00000-of-00001', '.index', '.meta']:\n",
312
+ " shutil.copyfile(f'tapas_sqa_base/model.ckpt{suffix}', f'results/sqa/model/model.ckpt-0{suffix}')"
313
+ ],
314
+ "execution_count": null,
315
+ "outputs": []
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "metadata": {
320
+ "id": "qJDsiHLWGOoO"
321
+ },
322
+ "source": [
323
+ "df = pd.read_csv(\"/content/sales.csv\")"
324
+ ],
325
+ "execution_count": null,
326
+ "outputs": []
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "metadata": {
331
+ "id": "_ApMF3mAHitK"
332
+ },
333
+ "source": [
334
+ "df = df.astype(str)"
335
+ ],
336
+ "execution_count": null,
337
+ "outputs": []
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "metadata": {
342
+ "id": "x4sR97-zGyir",
343
+ "colab": {
344
+ "base_uri": "https://localhost:8080/",
345
+ "height": 355
346
+ },
347
+ "outputId": "3fb679c5-6391-4249-8037-518634782b7b"
348
+ },
349
+ "source": [
350
+ "df.head(10)"
351
+ ],
352
+ "execution_count": null,
353
+ "outputs": [
354
+ {
355
+ "output_type": "execute_result",
356
+ "data": {
357
+ "text/html": [
358
+ "<div>\n",
359
+ "<style scoped>\n",
360
+ " .dataframe tbody tr th:only-of-type {\n",
361
+ " vertical-align: middle;\n",
362
+ " }\n",
363
+ "\n",
364
+ " .dataframe tbody tr th {\n",
365
+ " vertical-align: top;\n",
366
+ " }\n",
367
+ "\n",
368
+ " .dataframe thead th {\n",
369
+ " text-align: right;\n",
370
+ " }\n",
371
+ "</style>\n",
372
+ "<table border=\"1\" class=\"dataframe\">\n",
373
+ " <thead>\n",
374
+ " <tr style=\"text-align: right;\">\n",
375
+ " <th></th>\n",
376
+ " <th>Pos</th>\n",
377
+ " <th>Player</th>\n",
378
+ " <th>Team</th>\n",
379
+ " <th>Span</th>\n",
380
+ " <th>Innings</th>\n",
381
+ " <th>Runs</th>\n",
382
+ " <th>Highest Score</th>\n",
383
+ " <th>Average</th>\n",
384
+ " <th>Strike Rate</th>\n",
385
+ " </tr>\n",
386
+ " </thead>\n",
387
+ " <tbody>\n",
388
+ " <tr>\n",
389
+ " <th>0</th>\n",
390
+ " <td>1</td>\n",
391
+ " <td>Sachin Tendulkar</td>\n",
392
+ " <td>India</td>\n",
393
+ " <td>1989-2012</td>\n",
394
+ " <td>452</td>\n",
395
+ " <td>18426</td>\n",
396
+ " <td>200</td>\n",
397
+ " <td>44.83</td>\n",
398
+ " <td>86.23</td>\n",
399
+ " </tr>\n",
400
+ " <tr>\n",
401
+ " <th>1</th>\n",
402
+ " <td>2</td>\n",
403
+ " <td>Kumar Sangakkara</td>\n",
404
+ " <td>Sri Lanka</td>\n",
405
+ " <td>2000-2015</td>\n",
406
+ " <td>380</td>\n",
407
+ " <td>14234</td>\n",
408
+ " <td>169</td>\n",
409
+ " <td>41.98</td>\n",
410
+ " <td>78.86</td>\n",
411
+ " </tr>\n",
412
+ " <tr>\n",
413
+ " <th>2</th>\n",
414
+ " <td>3</td>\n",
415
+ " <td>Ricky Ponting</td>\n",
416
+ " <td>Australia</td>\n",
417
+ " <td>1995-2012</td>\n",
418
+ " <td>365</td>\n",
419
+ " <td>13704</td>\n",
420
+ " <td>164</td>\n",
421
+ " <td>42.03</td>\n",
422
+ " <td>80.39</td>\n",
423
+ " </tr>\n",
424
+ " <tr>\n",
425
+ " <th>3</th>\n",
426
+ " <td>4</td>\n",
427
+ " <td>Sanath Jayasuriya</td>\n",
428
+ " <td>Sri Lanka</td>\n",
429
+ " <td>1989-2011</td>\n",
430
+ " <td>433</td>\n",
431
+ " <td>13430</td>\n",
432
+ " <td>189</td>\n",
433
+ " <td>32.36</td>\n",
434
+ " <td>91.2</td>\n",
435
+ " </tr>\n",
436
+ " <tr>\n",
437
+ " <th>4</th>\n",
438
+ " <td>5</td>\n",
439
+ " <td>Mahela Jayawardene</td>\n",
440
+ " <td>Sri Lanka</td>\n",
441
+ " <td>1998-2015</td>\n",
442
+ " <td>418</td>\n",
443
+ " <td>12650</td>\n",
444
+ " <td>144</td>\n",
445
+ " <td>33.37</td>\n",
446
+ " <td>78.96</td>\n",
447
+ " </tr>\n",
448
+ " <tr>\n",
449
+ " <th>5</th>\n",
450
+ " <td>6</td>\n",
451
+ " <td>Virat Kohli</td>\n",
452
+ " <td>India</td>\n",
453
+ " <td>2008-2020</td>\n",
454
+ " <td>236</td>\n",
455
+ " <td>11867</td>\n",
456
+ " <td>183</td>\n",
457
+ " <td>59.85</td>\n",
458
+ " <td>93.39</td>\n",
459
+ " </tr>\n",
460
+ " <tr>\n",
461
+ " <th>6</th>\n",
462
+ " <td>7</td>\n",
463
+ " <td>Inzamam-ul-Haq</td>\n",
464
+ " <td>Pakistan</td>\n",
465
+ " <td>1991-2007</td>\n",
466
+ " <td>350</td>\n",
467
+ " <td>11739</td>\n",
468
+ " <td>137</td>\n",
469
+ " <td>39.52</td>\n",
470
+ " <td>74.24</td>\n",
471
+ " </tr>\n",
472
+ " <tr>\n",
473
+ " <th>7</th>\n",
474
+ " <td>8</td>\n",
475
+ " <td>Jacques Kallis</td>\n",
476
+ " <td>South Africa</td>\n",
477
+ " <td>1996-2014</td>\n",
478
+ " <td>314</td>\n",
479
+ " <td>11579</td>\n",
480
+ " <td>139</td>\n",
481
+ " <td>44.36</td>\n",
482
+ " <td>72.89</td>\n",
483
+ " </tr>\n",
484
+ " <tr>\n",
485
+ " <th>8</th>\n",
486
+ " <td>9</td>\n",
487
+ " <td>Saurav Ganguly</td>\n",
488
+ " <td>India</td>\n",
489
+ " <td>1992-2007</td>\n",
490
+ " <td>300</td>\n",
491
+ " <td>11363</td>\n",
492
+ " <td>183</td>\n",
493
+ " <td>41.02</td>\n",
494
+ " <td>73.7</td>\n",
495
+ " </tr>\n",
496
+ " <tr>\n",
497
+ " <th>9</th>\n",
498
+ " <td>10</td>\n",
499
+ " <td>Rahul Dravid</td>\n",
500
+ " <td>India</td>\n",
501
+ " <td>1996-2011</td>\n",
502
+ " <td>318</td>\n",
503
+ " <td>10889</td>\n",
504
+ " <td>153</td>\n",
505
+ " <td>39.16</td>\n",
506
+ " <td>71.24</td>\n",
507
+ " </tr>\n",
508
+ " </tbody>\n",
509
+ "</table>\n",
510
+ "</div>"
511
+ ],
512
+ "text/plain": [
513
+ " Pos Player Team ... Highest Score Average Strike Rate\n",
514
+ "0 1 Sachin Tendulkar India ... 200 44.83 86.23\n",
515
+ "1 2 Kumar Sangakkara Sri Lanka ... 169 41.98 78.86\n",
516
+ "2 3 Ricky Ponting Australia ... 164 42.03 80.39\n",
517
+ "3 4 Sanath Jayasuriya Sri Lanka ... 189 32.36 91.2\n",
518
+ "4 5 Mahela Jayawardene Sri Lanka ... 144 33.37 78.96\n",
519
+ "5 6 Virat Kohli India ... 183 59.85 93.39\n",
520
+ "6 7 Inzamam-ul-Haq Pakistan ... 137 39.52 74.24\n",
521
+ "7 8 Jacques Kallis South Africa ... 139 44.36 72.89\n",
522
+ "8 9 Saurav Ganguly India ... 183 41.02 73.7\n",
523
+ "9 10 Rahul Dravid India ... 153 39.16 71.24\n",
524
+ "\n",
525
+ "[10 rows x 9 columns]"
526
+ ]
527
+ },
528
+ "metadata": {
529
+ "tags": []
530
+ },
531
+ "execution_count": 8
532
+ }
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "metadata": {
538
+ "id": "aQy0TbnHKgrw",
539
+ "colab": {
540
+ "base_uri": "https://localhost:8080/",
541
+ "height": 1000
542
+ },
543
+ "outputId": "44e94e26-2772-4f31-ab7e-38286ed3faaf"
544
+ },
545
+ "source": [
546
+ "list_of_list"
547
+ ],
548
+ "execution_count": null,
549
+ "outputs": [
550
+ {
551
+ "output_type": "execute_result",
552
+ "data": {
553
+ "text/plain": [
554
+ "[['Pos',\n",
555
+ " 'Player',\n",
556
+ " 'Team',\n",
557
+ " 'Span',\n",
558
+ " 'Innings',\n",
559
+ " 'Runs',\n",
560
+ " 'Highest Score',\n",
561
+ " 'Average',\n",
562
+ " 'Strike Rate'],\n",
563
+ " ['1',\n",
564
+ " 'Sachin Tendulkar',\n",
565
+ " 'India',\n",
566
+ " '1989-2012',\n",
567
+ " '452',\n",
568
+ " '18426',\n",
569
+ " '200',\n",
570
+ " '44.83',\n",
571
+ " '86.23'],\n",
572
+ " ['2',\n",
573
+ " 'Kumar Sangakkara',\n",
574
+ " 'Sri Lanka',\n",
575
+ " '2000-2015',\n",
576
+ " '380',\n",
577
+ " '14234',\n",
578
+ " '169',\n",
579
+ " '41.98',\n",
580
+ " '78.86'],\n",
581
+ " ['3',\n",
582
+ " 'Ricky Ponting',\n",
583
+ " 'Australia',\n",
584
+ " '1995-2012',\n",
585
+ " '365',\n",
586
+ " '13704',\n",
587
+ " '164',\n",
588
+ " '42.03',\n",
589
+ " '80.39'],\n",
590
+ " ['4',\n",
591
+ " 'Sanath Jayasuriya',\n",
592
+ " 'Sri Lanka',\n",
593
+ " '1989-2011',\n",
594
+ " '433',\n",
595
+ " '13430',\n",
596
+ " '189',\n",
597
+ " '32.36',\n",
598
+ " '91.2'],\n",
599
+ " ['5',\n",
600
+ " 'Mahela Jayawardene',\n",
601
+ " 'Sri Lanka',\n",
602
+ " '1998-2015',\n",
603
+ " '418',\n",
604
+ " '12650',\n",
605
+ " '144',\n",
606
+ " '33.37',\n",
607
+ " '78.96'],\n",
608
+ " ['6',\n",
609
+ " 'Virat Kohli',\n",
610
+ " 'India',\n",
611
+ " '2008-2020',\n",
612
+ " '236',\n",
613
+ " '11867',\n",
614
+ " '183',\n",
615
+ " '59.85',\n",
616
+ " '93.39'],\n",
617
+ " ['7',\n",
618
+ " 'Inzamam-ul-Haq',\n",
619
+ " 'Pakistan',\n",
620
+ " '1991-2007',\n",
621
+ " '350',\n",
622
+ " '11739',\n",
623
+ " '137',\n",
624
+ " '39.52',\n",
625
+ " '74.24'],\n",
626
+ " ['8',\n",
627
+ " 'Jacques Kallis',\n",
628
+ " 'South Africa',\n",
629
+ " '1996-2014',\n",
630
+ " '314',\n",
631
+ " '11579',\n",
632
+ " '139',\n",
633
+ " '44.36',\n",
634
+ " '72.89'],\n",
635
+ " ['9',\n",
636
+ " 'Saurav Ganguly',\n",
637
+ " 'India',\n",
638
+ " '1992-2007',\n",
639
+ " '300',\n",
640
+ " '11363',\n",
641
+ " '183',\n",
642
+ " '41.02',\n",
643
+ " '73.7'],\n",
644
+ " ['10',\n",
645
+ " 'Rahul Dravid',\n",
646
+ " 'India',\n",
647
+ " '1996-2011',\n",
648
+ " '318',\n",
649
+ " '10889',\n",
650
+ " '153',\n",
651
+ " '39.16',\n",
652
+ " '71.24']]"
653
+ ]
654
+ },
655
+ "metadata": {
656
+ "tags": []
657
+ },
658
+ "execution_count": 10
659
+ }
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "code",
664
+ "metadata": {
665
+ "id": "kpv469YIKgu1"
666
+ },
667
+ "source": [
668
+ "list_of_list = [[]]\n",
669
+ "list_of_list[0] = list(df.columns)\n",
670
+ "list_of_list.extend(df.values.tolist())"
671
+ ],
672
+ "execution_count": null,
673
+ "outputs": []
674
+ },
675
+ {
676
+ "cell_type": "code",
677
+ "metadata": {
678
+ "id": "9RlvgDAmCNtP"
679
+ },
680
+ "source": [
681
+ "max_seq_length = 512\n",
682
+ "vocab_file = \"tapas_sqa_base/vocab.txt\"\n",
683
+ "config = tf_example_utils.ClassifierConversionConfig(\n",
684
+ " vocab_file=vocab_file,\n",
685
+ " max_seq_length=max_seq_length,\n",
686
+ " max_column_id=max_seq_length,\n",
687
+ " max_row_id=max_seq_length,\n",
688
+ " strip_column_names=False,\n",
689
+ " add_aggregation_candidates=False,\n",
690
+ ")\n",
691
+ "converter = tf_example_utils.ToClassifierTensorflowExample(config)\n",
692
+ "\n",
693
+ "def convert_interactions_to_examples(tables_and_queries):\n",
694
+ " \"\"\"Calls Tapas converter to convert interaction to example.\"\"\"\n",
695
+ " for idx, (table, queries) in enumerate(tables_and_queries):\n",
696
+ " interaction = interaction_pb2.Interaction()\n",
697
+ " for position, query in enumerate(queries):\n",
698
+ " question = interaction.questions.add()\n",
699
+ " question.original_text = query\n",
700
+ " question.id = f\"{idx}-0_{position}\"\n",
701
+ " for header in table[0]:\n",
702
+ " interaction.table.columns.add().text = header\n",
703
+ " for line in table[1:]:\n",
704
+ " row = interaction.table.rows.add()\n",
705
+ " for cell in line:\n",
706
+ " row.cells.add().text = cell\n",
707
+ " number_annotation_utils.add_numeric_values(interaction)\n",
708
+ " for i in range(len(interaction.questions)):\n",
709
+ " try:\n",
710
+ " yield converter.convert(interaction, i)\n",
711
+ " except ValueError as e:\n",
712
+ " print(f\"Can't convert interaction: {interaction.id} error: {e}\")\n",
713
+ "\n",
714
+ "def write_tf_example(filename, examples):\n",
715
+ " with tf.io.TFRecordWriter(filename) as writer:\n",
716
+ " for example in examples:\n",
717
+ " writer.write(example.SerializeToString())\n",
718
+ "\n",
719
+ "def predict(table_data, queries):\n",
720
+ " table = table_data\n",
721
+ " examples = convert_interactions_to_examples([(table, queries)])\n",
722
+ " write_tf_example(\"results/sqa/tf_examples/test.tfrecord\", examples)\n",
723
+ " write_tf_example(\"results/sqa/tf_examples/random-split-1-dev.tfrecord\", [])\n",
724
+ "\n",
725
+ " ! python tapas/tapas/run_task_main.py \\\n",
726
+ " --task=\"SQA\" \\\n",
727
+ " --output_dir=\"results\" \\\n",
728
+ " --noloop_predict \\\n",
729
+ " --test_batch_size={len(queries)} \\\n",
730
+ " --tapas_verbosity=\"ERROR\" \\\n",
731
+ " --compression_type= \\\n",
732
+ " --init_checkpoint=\"tapas_sqa_base/model.ckpt\" \\\n",
733
+ " --bert_config_file=\"tapas_sqa_base/bert_config.json\" \\\n",
734
+ " --mode=\"predict\" 2> error\n",
735
+ "\n",
736
+ "\n",
737
+ " results_path = \"results/sqa/model/test_sequence.tsv\"\n",
738
+ " all_coordinates = []\n",
739
+ " df = pd.DataFrame(table[1:], columns=table[0])\n",
740
+ " display(IPython.display.HTML(df.to_html(index=False)))\n",
741
+ " print()\n",
742
+ " with open(results_path) as csvfile:\n",
743
+ " reader = csv.DictReader(csvfile, delimiter='\\t')\n",
744
+ " for row in reader:\n",
745
+ " coordinates = prediction_utils.parse_coordinates(row[\"answer_coordinates\"])\n",
746
+ " all_coordinates.append(coordinates)\n",
747
+ " answers = ', '.join([table[row + 1][col] for row, col in coordinates])\n",
748
+ " position = int(row['position'])\n",
749
+ " print(\">\", queries[position])\n",
750
+ " print(answers)\n",
751
+ " return all_coordinates"
752
+ ],
753
+ "execution_count": null,
754
+ "outputs": []
755
+ },
756
+ {
757
+ "cell_type": "markdown",
758
+ "metadata": {
759
+ "id": "Gqu-I-M9QaoA"
760
+ },
761
+ "source": [
762
+ "# Predict"
763
+ ]
764
+ },
765
+ {
766
+ "cell_type": "code",
767
+ "metadata": {
768
+ "id": "SIE7bTJMVuSh",
769
+ "colab": {
770
+ "base_uri": "https://localhost:8080/",
771
+ "height": 618
772
+ },
773
+ "outputId": "a960fcd7-cdab-499c-d81c-5d0d8c1ccc69"
774
+ },
775
+ "source": [
776
+ "result = predict(list_of_list, [\"what were the players names?\",\n",
777
+ " \"of these, which team did Sachin Tendulkar play for?\",\n",
778
+ " \"what is his highest score?\",\n",
779
+ " \"how many runs has Virat Kohli scored?\"])"
780
+ ],
781
+ "execution_count": null,
782
+ "outputs": [
783
+ {
784
+ "output_type": "stream",
785
+ "text": [
786
+ "is_built_with_cuda: True\n",
787
+ "is_gpu_available: True\n",
788
+ "GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n",
789
+ "Training or predicting ...\n",
790
+ "Evaluation finished after training step 0.\n"
791
+ ],
792
+ "name": "stdout"
793
+ },
794
+ {
795
+ "output_type": "display_data",
796
+ "data": {
797
+ "text/html": [
798
+ "<table border=\"1\" class=\"dataframe\">\n",
799
+ " <thead>\n",
800
+ " <tr style=\"text-align: right;\">\n",
801
+ " <th>Pos</th>\n",
802
+ " <th>Player</th>\n",
803
+ " <th>Team</th>\n",
804
+ " <th>Span</th>\n",
805
+ " <th>Innings</th>\n",
806
+ " <th>Runs</th>\n",
807
+ " <th>Highest Score</th>\n",
808
+ " <th>Average</th>\n",
809
+ " <th>Strike Rate</th>\n",
810
+ " </tr>\n",
811
+ " </thead>\n",
812
+ " <tbody>\n",
813
+ " <tr>\n",
814
+ " <td>1</td>\n",
815
+ " <td>Sachin Tendulkar</td>\n",
816
+ " <td>India</td>\n",
817
+ " <td>1989-2012</td>\n",
818
+ " <td>452</td>\n",
819
+ " <td>18426</td>\n",
820
+ " <td>200</td>\n",
821
+ " <td>44.83</td>\n",
822
+ " <td>86.23</td>\n",
823
+ " </tr>\n",
824
+ " <tr>\n",
825
+ " <td>2</td>\n",
826
+ " <td>Kumar Sangakkara</td>\n",
827
+ " <td>Sri Lanka</td>\n",
828
+ " <td>2000-2015</td>\n",
829
+ " <td>380</td>\n",
830
+ " <td>14234</td>\n",
831
+ " <td>169</td>\n",
832
+ " <td>41.98</td>\n",
833
+ " <td>78.86</td>\n",
834
+ " </tr>\n",
835
+ " <tr>\n",
836
+ " <td>3</td>\n",
837
+ " <td>Ricky Ponting</td>\n",
838
+ " <td>Australia</td>\n",
839
+ " <td>1995-2012</td>\n",
840
+ " <td>365</td>\n",
841
+ " <td>13704</td>\n",
842
+ " <td>164</td>\n",
843
+ " <td>42.03</td>\n",
844
+ " <td>80.39</td>\n",
845
+ " </tr>\n",
846
+ " <tr>\n",
847
+ " <td>4</td>\n",
848
+ " <td>Sanath Jayasuriya</td>\n",
849
+ " <td>Sri Lanka</td>\n",
850
+ " <td>1989-2011</td>\n",
851
+ " <td>433</td>\n",
852
+ " <td>13430</td>\n",
853
+ " <td>189</td>\n",
854
+ " <td>32.36</td>\n",
855
+ " <td>91.2</td>\n",
856
+ " </tr>\n",
857
+ " <tr>\n",
858
+ " <td>5</td>\n",
859
+ " <td>Mahela Jayawardene</td>\n",
860
+ " <td>Sri Lanka</td>\n",
861
+ " <td>1998-2015</td>\n",
862
+ " <td>418</td>\n",
863
+ " <td>12650</td>\n",
864
+ " <td>144</td>\n",
865
+ " <td>33.37</td>\n",
866
+ " <td>78.96</td>\n",
867
+ " </tr>\n",
868
+ " <tr>\n",
869
+ " <td>6</td>\n",
870
+ " <td>Virat Kohli</td>\n",
871
+ " <td>India</td>\n",
872
+ " <td>2008-2020</td>\n",
873
+ " <td>236</td>\n",
874
+ " <td>11867</td>\n",
875
+ " <td>183</td>\n",
876
+ " <td>59.85</td>\n",
877
+ " <td>93.39</td>\n",
878
+ " </tr>\n",
879
+ " <tr>\n",
880
+ " <td>7</td>\n",
881
+ " <td>Inzamam-ul-Haq</td>\n",
882
+ " <td>Pakistan</td>\n",
883
+ " <td>1991-2007</td>\n",
884
+ " <td>350</td>\n",
885
+ " <td>11739</td>\n",
886
+ " <td>137</td>\n",
887
+ " <td>39.52</td>\n",
888
+ " <td>74.24</td>\n",
889
+ " </tr>\n",
890
+ " <tr>\n",
891
+ " <td>8</td>\n",
892
+ " <td>Jacques Kallis</td>\n",
893
+ " <td>South Africa</td>\n",
894
+ " <td>1996-2014</td>\n",
895
+ " <td>314</td>\n",
896
+ " <td>11579</td>\n",
897
+ " <td>139</td>\n",
898
+ " <td>44.36</td>\n",
899
+ " <td>72.89</td>\n",
900
+ " </tr>\n",
901
+ " <tr>\n",
902
+ " <td>9</td>\n",
903
+ " <td>Saurav Ganguly</td>\n",
904
+ " <td>India</td>\n",
905
+ " <td>1992-2007</td>\n",
906
+ " <td>300</td>\n",
907
+ " <td>11363</td>\n",
908
+ " <td>183</td>\n",
909
+ " <td>41.02</td>\n",
910
+ " <td>73.7</td>\n",
911
+ " </tr>\n",
912
+ " <tr>\n",
913
+ " <td>10</td>\n",
914
+ " <td>Rahul Dravid</td>\n",
915
+ " <td>India</td>\n",
916
+ " <td>1996-2011</td>\n",
917
+ " <td>318</td>\n",
918
+ " <td>10889</td>\n",
919
+ " <td>153</td>\n",
920
+ " <td>39.16</td>\n",
921
+ " <td>71.24</td>\n",
922
+ " </tr>\n",
923
+ " </tbody>\n",
924
+ "</table>"
925
+ ],
926
+ "text/plain": [
927
+ "<IPython.core.display.HTML object>"
928
+ ]
929
+ },
930
+ "metadata": {
931
+ "tags": []
932
+ }
933
+ },
934
+ {
935
+ "output_type": "stream",
936
+ "text": [
937
+ "\n",
938
+ "> what were the players names?\n",
939
+ "Sachin Tendulkar, Rahul Dravid, Jacques Kallis, Saurav Ganguly, Inzamam-ul-Haq, Sanath Jayasuriya, Ricky Ponting, Virat Kohli, Mahela Jayawardene, Kumar Sangakkara\n",
940
+ "> of these, which team did Sachin Tendulkar play for?\n",
941
+ "India\n",
942
+ "> what is his highest score?\n",
943
+ "200\n",
944
+ "> how many runs has Virat Kohli scored?\n",
945
+ "11867\n"
946
+ ],
947
+ "name": "stdout"
948
+ }
949
+ ]
950
+ },
951
+ {
952
+ "cell_type": "code",
953
+ "metadata": {
954
+ "id": "4WxPWUXVGh2W"
955
+ },
956
+ "source": [],
957
+ "execution_count": null,
958
+ "outputs": []
959
+ }
960
+ ]
961
+ }