import numpy as np
from sklearn.neural_network import MLPRegressor
from sklearn.ensemble import RandomForestRegressor


class SurrogateHandler:
    """Collects (position, objective) samples and fits a regression surrogate.

    Position matrices are flattened to 1-D feature vectors before storage,
    so every sample added must have the same total number of elements.
    """

    def __init__(self, model_type='mlp', min_samples=20):
        """Create an untrained surrogate.

        Parameters
        ----------
        model_type : {'mlp', 'rf'}
            'mlp' -> MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500);
            'rf'  -> RandomForestRegressor(n_estimators=100).
            Both use random_state=42 for reproducibility.
        min_samples : int
            Minimum number of stored samples before train() will actually fit
            (was a hard-coded 20; now parameterized, default unchanged).

        Raises
        ------
        ValueError
            If model_type is neither 'mlp' nor 'rf'.
        """
        self.model_type = model_type
        self.min_samples = min_samples
        self.is_trained = False
        self.data_X = []
        self.data_Y = []

        # Model choice
        if model_type == 'mlp':
            self.model = MLPRegressor(hidden_layer_sizes=(100, 50),
                                      max_iter=500, random_state=42)
        elif model_type == 'rf':
            # RandomForest is generally more robust "out of the box"
            self.model = RandomForestRegressor(n_estimators=100,
                                               random_state=42)
        else:
            raise ValueError("Model type must be 'mlp' or 'rf'")

    def add_data(self, x_matrix, f2_value):
        """Store one training sample.

        x_matrix is flattened to a 1-D vector; f2_value is the scalar target.
        """
        # np.array copies the input, then ravel views that copy — one copy
        # total (flatten() on top of np.array would have made a second one),
        # and the stored vector never aliases the caller's array.
        flat_x = np.array(x_matrix).ravel()
        self.data_X.append(flat_x)
        self.data_Y.append(f2_value)

    def train(self):
        """Fit the model on all stored samples.

        Silently a no-op while fewer than min_samples samples are stored
        (deliberate: callers poll train() as data accumulates).
        Sets is_trained to True after a successful fit.
        """
        if len(self.data_X) < self.min_samples:  # too few samples to fit yet
            return

        X = np.array(self.data_X)
        y = np.array(self.data_Y)
        self.model.fit(X, y)
        self.is_trained = True

    def predict(self, x_matrix):
        """Return the predicted f2 value for x_matrix.

        Returns None if the model has not been trained yet, so callers can
        fall back to the true objective.
        """
        if not self.is_trained:
            return None
        # Reshape to the (1, n_features) row sklearn expects for one sample.
        flat_x = np.array(x_matrix).ravel().reshape(1, -1)
        return self.model.predict(flat_x)[0]