import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, Softmax
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
|
|
# Vocabulary: every character the model can read or emit. A character's
# column in a one-hot row is its position in this string.
CHAR_SET = '0123456789+-=* /'
NUM_CLASSES = len(CHAR_SET)

# Fixed sequence lengths (in characters) for encoded equations and results.
# Inputs are truncated to these lengths and padded up to them, so they
# directly size the training arrays and the Dense layers. The previous values
# (2,000,000 and 100,000) made every encoded equation ~256 MB of float64 and
# gave the first Dense layer ~4.1e9 weights, which cannot be trained.
MAX_EQUATION_LENGTH = 64
MAX_RESULT_LENGTH = 32
|
|
def one_hot_encode(s, max_length, pad_char=' '):
    """Encode string *s* as a fixed-size one-hot matrix.

    Row ``i`` holds a single 1.0 in column ``CHAR_SET.index(s[i])``.
    Characters past *max_length* are dropped. Characters not in CHAR_SET
    leave their row all zeros.

    Rows after the end of *s* are filled with the one-hot code of
    *pad_char*. All-zero padding rows give a categorical cross-entropy
    target nothing to learn there, so a trained model's output at those
    positions is unconstrained. Padding with a space instead means padding
    decodes to whitespace that ``str.strip()`` can remove. Pass
    ``pad_char=None`` to keep the old all-zero padding.

    Args:
        s: Text to encode.
        max_length: Number of rows in the result; *s* is truncated to it.
        pad_char: Single CHAR_SET character used for padding rows, or None
            for all-zero padding.

    Returns:
        float32 ndarray of shape ``(max_length, NUM_CLASSES)``.

    Raises:
        ValueError: if *pad_char* is not None and is not a single CHAR_SET
            character.
    """
    if pad_char is not None and (len(pad_char) != 1 or pad_char not in CHAR_SET):
        raise ValueError(f"pad_char {pad_char!r} is not a single CHAR_SET character")

    # float32 is what Keras trains in anyway; float64 would double memory.
    encoding = np.zeros((max_length, NUM_CLASSES), dtype=np.float32)
    text = s[:max_length]
    for i, char in enumerate(text):
        column = CHAR_SET.find(char)
        if column >= 0:  # unknown characters stay as an all-zero row
            encoding[i, column] = 1.0

    if pad_char is not None:
        encoding[len(text):, CHAR_SET.index(pad_char)] = 1.0
    return encoding
|
|
def _parse_line(line):
    """Split one ``"<equation> = <result>"`` line into its stripped parts.

    Returns ``(equation, result)``. Returns None when the line does not
    contain exactly one '=' or when either side is empty after stripping.
    """
    parts = line.strip().split('=')
    if len(parts) != 2:
        # No '=' means not a sample; several '=' is ambiguous. Previously the
        # latter crashed the whole load with "too many values to unpack".
        return None
    equation, result = (part.strip() for part in parts)
    if not equation or not result:
        return None
    return equation, result


def read_dataset(directory):
    """Load every ``*.txt`` file in *directory* as one-hot training pairs.

    Each line of the form ``<equation> = <result>`` becomes one sample.
    The equation is encoded with ``one_hot_encode(..., MAX_EQUATION_LENGTH)``
    and the result with ``one_hot_encode(..., MAX_RESULT_LENGTH)``. Lines
    that do not split into exactly two non-empty sides are skipped.

    Args:
        directory: Path of the directory to scan (non-recursive).

    Returns:
        ``(data, labels)`` as the ``np.array`` of the collected encodings.
        They have shape ``(n, MAX_EQUATION_LENGTH, NUM_CLASSES)`` and
        ``(n, MAX_RESULT_LENGTH, NUM_CLASSES)`` when at least one sample
        exists, and shape ``(0,)`` otherwise.
    """
    data = []
    labels = []

    # os.listdir order is arbitrary; sort so sample order is reproducible.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.txt'):
            continue
        path = os.path.join(directory, filename)
        with open(path, 'r', encoding='utf-8') as file:
            for line in file:
                parsed = _parse_line(line)
                if parsed is None:
                    continue
                equation, result = parsed
                data.append(one_hot_encode(equation, MAX_EQUATION_LENGTH))
                labels.append(one_hot_encode(result, MAX_RESULT_LENGTH))

    return np.array(data), np.array(labels)
|
|
# --- Training -----------------------------------------------------------------

TRAIN_DIR = '.math_train'

data, labels = read_dataset(TRAIN_DIR)
if len(data) == 0:
    # An empty dataset comes back as shape (0,) and would otherwise fail
    # below with an opaque reshape error; fail early with a clear message.
    raise ValueError(f"no '<equation> = <result>' lines found in *.txt files under {TRAIN_DIR!r}")

# One one-hot row per output character position:
# (num_samples, MAX_RESULT_LENGTH, NUM_CLASSES).
labels = labels.reshape((labels.shape[0], -1, NUM_CLASSES))

model = Sequential([
    Flatten(input_shape=(MAX_EQUATION_LENGTH, NUM_CLASSES)),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    # Emit one logit vector per output position and normalise each
    # separately. A single softmax across all MAX_RESULT_LENGTH * NUM_CLASSES
    # units makes the positions compete for one probability mass, even though
    # every target row carries its own 1.
    Dense(MAX_RESULT_LENGTH * NUM_CLASSES),
    Reshape((MAX_RESULT_LENGTH, NUM_CLASSES)),
    Softmax(axis=-1),
])

# categorical_crossentropy / accuracy reduce over the last axis, i.e. per
# output position.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Targets keep their (samples, positions, classes) shape to match the output.
model.fit(data, labels, epochs=50, batch_size=32)
|
|
# --- Inference ----------------------------------------------------------------
def solve_equation(model, equation):
    """Return the model's predicted result string for *equation*.

    The equation is one-hot encoded to MAX_EQUATION_LENGTH rows and fed to
    ``model.predict`` as a batch of one. The prediction is reshaped to
    ``(MAX_RESULT_LENGTH, NUM_CLASSES)``, and each row is decoded to the
    CHAR_SET character with the highest score. Surrounding whitespace is
    stripped from the decoded text.
    """
    batch = one_hot_encode(equation, MAX_EQUATION_LENGTH)[np.newaxis, ...]
    scores = model.predict(batch).reshape((MAX_RESULT_LENGTH, NUM_CLASSES))
    best = scores.argmax(axis=-1)
    # Keep only indices that map to a CHAR_SET character (order preserved).
    in_range = best[best < len(CHAR_SET)]
    return ''.join([CHAR_SET[k] for k in in_range]).strip()
|
|
|
|
# Sanity check: decode one hand-written equation with the trained model.
equation = "1 + 1"
result = solve_equation(model, equation)
print("The result of '{}' is '{}'".format(equation, result))