| | import gradio as gr |
| | import torch |
| |
|
| |
|
# Markdown template for the example snippet shown at the top of every answer.
# generate_example() fills it with str.format using:
#   n1 / n2     -> element counts passed to torch.arange
#   dim1 / dim2 -> the two shapes passed to .view
#   out_shape   -> the shape of t1 @ t2, or the string "error" when it fails
EXAMPLE_MD = """
```python
import torch

t1 = torch.arange({n1}).view({dim1})

t2 = torch.arange({n2}).view({dim2})

(t1 @ t2).shape = {out_shape}

```

"""


# Markdown snippet used by predict() for the 1-D @ 1-D case: it illustrates that
# the result is a single accumulated value (the sum of element-wise products).
matrix_loop = """```python
out = 0
for i, j in zip(t1, t2):
    out += i * j
```
"""
| |
|
| |
|
def generate_example(dim1: list, dim2: list):
    """Build the Markdown example snippet for ``t1 @ t2`` with the given shapes.

    Args:
        dim1: shape of ``t1`` as a list of ints.
        dim2: shape of ``t2`` as a list of ints.

    Returns:
        tuple: ``(dim1, dim2, code)``. ``dim1`` and ``dim2`` are the very same
        list objects that were passed in. ``code`` is ``EXAMPLE_MD`` filled with
        the element counts, the shapes, and the shape PyTorch reports for
        ``t1 @ t2``, or the string ``"error"`` if the multiplication raises
        ``RuntimeError``.
    """

    def count_elements(shape):
        # Product of all sizes; an empty shape counts as one element.
        total = 1
        for size in shape:
            total *= size
        return total

    n1 = count_elements(dim1)
    n2 = count_elements(dim2)

    # Real tensors are built so the reported shape comes from PyTorch itself.
    # NOTE(review): torch.arange allocates every element, so huge shapes may
    # exhaust memory; consider guarding the input size.
    t1 = torch.arange(n1).view(dim1)
    t2 = torch.arange(n2).view(dim2)
    try:
        product = t1 @ t2
    except RuntimeError:
        out_shape = "error"
    else:
        out_shape = list(product.shape)

    placeholders = {
        "n1": n1,
        "dim1": dim1,
        "n2": n2,
        "dim2": dim2,
        "out_shape": out_shape,
    }
    code = EXAMPLE_MD.format(**{key: str(value) for key, value in placeholders.items()})
    return dim1, dim2, code
| |
|
| |
|
def sanitize_dimension(dim):
    """Parse a user-typed shape such as ``"[4, 5, 6]"``, ``"4,5,6"`` or ``"4 5 6"``.

    Square brackets are ignored. Commas and/or whitespace separate the values.

    Args:
        dim: raw text from the Gradio textbox (may be ``None`` or empty).

    Returns:
        list[int]: the parsed dimension sizes, e.g. ``[4, 5, 6]``.

    Raises:
        gr.Error: if the text is empty, contains a value that is not an
            integer, or contains a size that is zero or negative. The error
            must be *raised*, not merely constructed, for Gradio to show it.
    """
    if dim is None:
        raise gr.Error("one of the dimensions is empty, please fill it")

    # Brackets are dropped and commas become separators, so "[4, 5]",
    # "4,5" and "4 5" all produce the same tokens.
    tokens = dim.replace("[", "").replace("]", "").replace(",", " ").split()
    if not tokens:
        raise gr.Error("one of the dimensions is empty, please fill it")

    try:
        out = [int(token) for token in tokens]
    except ValueError:
        raise gr.Error(
            f"could not read {dim!r} as a list of integers, use something like 4,5,6"
        ) from None

    if 0 in out:
        raise gr.Error(
            "Found the number 0 in one of the dimensions which is not allowed, consider using 1 instead"
        )
    if any(size < 0 for size in out):
        raise gr.Error("Found a negative number in one of the dimensions which is not allowed")
    return out
| |
|
| |
|
def create_row(dim, is_dim=None, checks=None, version=1):
    """Render one shape as a Markdown table row, colouring cells by their role.

    Args:
        dim: the dimension sizes, one per cell.
        is_dim: ``1`` when ``dim`` is the shape of ``t1``, ``2`` for ``t2``.
            Any other value (the default) leaves version-1 rows uncoloured.
        checks: per-column ``"V"``/``"X"`` markers. Only the entries a
            coloured cell needs are read.
        version: ``1`` for the layout where both tensors have >= 2 dims,
            ``2`` for the layout where one tensor is 1-D. Any other value
            produces a row with no cells.

    Returns:
        str: the row, starting with ``"| "`` and ending with a newline.

    Colours, version 1:
        * t1, second-to-last column: always green.
        * t1, earlier columns: green for "V", red otherwise.
        * t1, last column and t2, second-to-last column: blue for "V", yellow otherwise.
        * t2, last column: always green.
        * t2, earlier columns: uncoloured.
    Colours, version 2: every non-last column of t1 is green, the last column
    of either row is blue ("V") or yellow otherwise, all else is uncoloured.
    """
    last = len(dim) - 1
    if version not in (1, 2):
        return "| \n"

    def by_check(idx, ok_color, bad_color):
        return ok_color if checks[idx] == "V" else bad_color

    def color_for(idx):
        """Return the colour for cell ``idx``, or None for a plain cell."""
        if version == 2:
            if idx == last:
                return by_check(idx, "blue", "yellow")
            return "green" if is_dim == 1 else None
        if is_dim == 1:
            if idx == last - 1:
                return "green"
            if idx == last:
                return by_check(idx, "blue", "yellow")
            return by_check(idx, "green", "red")
        if is_dim == 2:
            # NOTE(review): t2's leading (batch) columns are never coloured;
            # confirm this is intended.
            if idx == last:
                return "green"
            if idx == last - 1:
                return by_check(idx, "blue", "yellow")
        return None

    cells = []
    for idx, size in enumerate(dim):
        color = color_for(idx)
        if color is None:
            cells.append(f"{size} | ")
        else:
            cells.append(f"<strong style='color: {color}'> {size} </strong>| ")
    return "| " + "".join(cells) + "\n"
| |
|
| |
|
def create_header(n_dim, checks=None):
    """Return the Markdown table header line plus its ``|---|`` divider line.

    Args:
        n_dim: number of columns in the divider.
        checks: header labels, one string per column. When ``None``, each of
            the ``n_dim`` columns gets the empty HTML comment ``<!-- -->``.

    Returns:
        str: two lines, each terminated by a newline.
    """
    labels = ["<!-- -->"] * n_dim if checks is None else checks
    title = "| " + "".join(label + " | " for label in labels)
    divider = "|---" * n_dim + "|\n"
    return title + "\n" + divider
| |
|
| |
|
def generate_table(dim1, dim2, checks=None, version=1):
    """Build a two-row Markdown table comparing the shapes of ``t1`` and ``t2``.

    Args:
        dim1: shape shown in the first row. Its length sets the column count.
        dim2: shape shown in the second row.
        checks: optional ``"V"``/``"X"`` markers. When given (non-empty), they
            label the header and colour the rows via ``create_row``.
        version: colouring layout forwarded to ``create_row``.

    Returns:
        str: header, divider, and the two rows.
    """
    table = create_header(len(dim1), checks)
    for row_id, dims in ((1, dim1), (2, dim2)):
        table += create_row(dims, row_id, checks, version) if checks else create_row(dims)
    return table
| |
|
| |
|
def alignment_and_fill_with_ones(dim1, dim2):
    """Left-pad the shorter shape with 1s so both shapes have the same length.

    Args:
        dim1: first shape.
        dim2: second shape.

    Returns:
        tuple: ``(dim1, dim2)``. The padded shape is a new list. A shape that
        needs no padding is returned as the same object that was passed in.
    """
    gap = len(dim2) - len(dim1)
    if gap > 0:
        dim1 = [1] * gap + list(dim1)
    elif gap < 0:
        dim2 = [1] * -gap + list(dim2)
    return dim1, dim2
| |
|
| |
|
def check_validity(dim1, dim2):
    """Mark which columns satisfy the batched matrix-multiplication rules.

    Every column except the last two must match position by position. The
    last size of ``dim1`` must equal the second-to-last size of ``dim2``, and
    that single verdict is used for both of the last two columns.

    Args:
        dim1: aligned shape of ``t1`` (at least two entries).
        dim2: aligned shape of ``t2`` (at least two entries).

    Returns:
        list[str]: ``"V"`` (ok) or ``"X"`` (mismatch), one per column of ``dim1``.
    """
    marks = ["V" if dim1[k] == dim2[k] else "X" for k in range(len(dim1) - 2)]
    inner = "V" if dim1[-1] == dim2[-2] else "X"
    return marks + [inner, inner]
| |
|
| |
|
def substitute_ones_with_concat(dim1, dim2, version=1):
    """Broadcast size-1 entries in place: a 1 in one shape takes the other's size.

    Args:
        dim1: first aligned shape (a mutable list; modified in place).
        dim2: second aligned shape (a mutable list; modified in place).
        version: ``1`` leaves the last two columns untouched. Any other value
            leaves only the last column untouched.

    Returns:
        tuple: ``(dim1, dim2)``, the same list objects that were passed in.
    """
    untouched = 2 if version == 1 else 1
    for idx in range(len(dim1) - untouched):
        first, second = dim1[idx], dim2[idx]
        if first == 1:
            first = second
        if second == 1:
            second = first
        dim1[idx], dim2[idx] = first, second
    return dim1, dim2
| |
|
| |
|
# Appended whenever the requested shapes cannot be multiplied (the example
# snippet at the top then shows `(t1 @ t2).shape = error`).
_INCOMPATIBLE_MD = (
    "\n# Result\n"
    "These shapes are not compatible, so `t1 @ t2` raises a `RuntimeError`\n"
)


def _batched_matmul_steps(dim1, dim2):
    """Return the step 1-3 Markdown for two shapes with >= 2 dims, plus the output shape.

    The returned shape is ``None`` when the shapes cannot be multiplied.
    ``substitute_ones_with_concat`` may modify the given lists in place.
    """
    dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
    table1 = generate_table(dim1, dim2)

    dim1, dim2 = substitute_ones_with_concat(dim1, dim2)
    table2 = generate_table(dim1, dim2)

    checks = check_validity(dim1, dim2)
    table3 = generate_table(dim1, dim2, checks)

    text = "\n# Step1 (alignment and pre_append with ones)\n" + table1
    text += (
        "\n# Step2 (substitute columns that have 1 with concat)\nexcept for last 2 dimensions\n"
        + table2
    )
    text += "\n# Step3 (check if matrix multiplication is valid)\n"
    text += "* last dimension of dim1 should equal before last dimension of dim2 (blue or yellow colors)\n"
    text += (
        "* all the other dimensions should be equal to one another (green or red colors)\n\n"
        + table3
    )
    if "X" in checks:
        return text, None
    # Batch dims (already broadcast in step 2) and rows come from t1; columns from t2.
    return text, dim1[:-1] + [dim2[-1]]


def _explain_batched_matmul(dim1, dim2):
    """Explain ``t1 @ t2`` when both tensors have at least two dimensions."""
    text, shape = _batched_matmul_steps(dim1, dim2)
    if shape is None:
        return text + _INCOMPATIBLE_MD
    text += "\n# Final dimension\n"
    text += "as highlighted in <strong style='color:green'> green </strong> \n\n"
    text += f"`output.shape = {shape}`"
    return text


def _explain_vector_vector(dim1, dim2):
    """Explain ``t1 @ t2`` when both tensors are 1-D (a sum of element-wise products)."""
    text = "# Single Dimensional Cases\n"
    text += "When both tensors have a single dimension they must have the same number of values\n"
    text += "meaning that `t1.shape == t2.shape`\n"
    text += "the output is a single value, think : \n"
    text += matrix_loop
    if dim1 != dim2:
        return text + _INCOMPATIBLE_MD
    text += "\n# Final dimension\n"
    text += "`output.shape = []` (a 0-dimensional tensor)\n"
    return text


def _explain_matrix_vector(dim1, dim2):
    """Explain ``t1 @ t2`` when ``t1`` has >= 2 dims and ``t2`` is 1-D."""
    text = "# The second tensor has a single dimension\n"
    text += "In this case we need to assert that the last dimension of `t1` "
    text += "is equal to the last dimension of `t2`\n"
    text += "Once the assertion is valid then we get rid of the last dimension and keep the rest\n"
    text += "# Step 1 (alignment and fill with ones)\n"
    dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
    text += generate_table(dim1, dim2)
    text += "\n# Step2 (substitute columns that have 1 with concat)\n"
    text += "replace the padded ones of `t2` with the matching dimensions of `t1`\n"
    dim1, dim2 = substitute_ones_with_concat(dim1, dim2, 2)
    checks = ["V"] * (len(dim1) - 1)
    checks.append("V" if dim1[-1] == dim2[-1] else "X")
    text += generate_table(dim1, dim2, checks, 2)
    if "X" in checks:
        return text + _INCOMPATIBLE_MD
    text += "\n# Final dimension\n"
    text += "The final dimension is everything colored in <strong style='color:green'> green </strong> \n"
    text += f"\nfinal dimension = `{dim1[:-1]}` "
    return text


def _explain_vector_matrix(dim1, dim2):
    """Explain ``t1 @ t2`` when ``t1`` is 1-D and ``t2`` has >= 2 dims."""
    # torch.matmul treats a 1-D first argument as a (1, n) matrix and removes
    # that prepended dimension from the result afterwards.
    padded = [1] + list(dim1)
    text = "# The first tensor has a single dimension\n"
    text += (
        "PyTorch temporarily prepends a dimension of size 1 to `t1` "
        f"(`{dim1}` becomes `{padded}`), multiplies it like a regular matrix, "
        "and then removes that extra dimension from the result\n"
    )
    steps, shape = _batched_matmul_steps(padded, dim2)
    text += steps
    if shape is None:
        return text + _INCOMPATIBLE_MD
    # shape[-2] is the prepended 1; dropping it gives the real output shape.
    final = shape[:-2] + shape[-1:]
    text += "\n# Final dimension\n"
    text += f"the multiplication gives `{shape}`; removing the prepended dimension of size 1 gives\n\n"
    text += f"`output.shape = {final}`"
    return text


def predict(dim1, dim2):
    """Build the Markdown explanation of ``t1 @ t2`` for two user-typed shapes.

    Args:
        dim1: raw textbox text describing the shape of ``t1``.
        dim2: raw textbox text describing the shape of ``t2``.

    Returns:
        str: Markdown containing the example snippet, a step-by-step walk
        through the shape rules for the matching case, and either the output
        shape or a note that the shapes are incompatible.
    """
    dim1 = sanitize_dimension(dim1)
    dim2 = sanitize_dimension(dim2)
    n1, n2 = len(dim1), len(dim2)
    dim1, dim2, out = generate_example(dim1, dim2)

    if n1 > 1 and n2 > 1:
        out += _explain_batched_matmul(dim1, dim2)
    elif n1 == 1 and n2 == 1:
        out += _explain_vector_vector(dim1, dim2)
    elif n2 == 1:
        out += _explain_matrix_vector(dim1, dim2)
    else:
        out += _explain_vector_matrix(dim1, dim2)
    return out
| |
|
| |
|
# Gradio UI: two free-text shape inputs, one Markdown explanation output.
demo = gr.Interface(
    predict,
    inputs=["text", "text"],
    outputs=["markdown"],
    examples=[
        ["9,2,1,3,3", "5,3,7"],
        ["7,4,2,3", "5,2,7"],
        ["4,5,6,7", "7"],
        ["7,5,3", "4"],
        ["5", "5"],
        ["8", "2"],
    ],
    title="Pytorch Matrix Multiplication",
    description="""There are 3 cases which are covered in the examples:
* Both matrices have dimensions bigger than 1
* One of the matrices has a single dimension
* Both matrices have a single dimension
""",
)

# Launch only when run as a script, so importing this module (tests, Gradio
# reload mode) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)
| |
|