41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
import numpy as np
|
|
from sklearn.neural_network import MLPRegressor
|
|
from sklearn.ensemble import RandomForestRegressor
|
|
|
|
class SurrogateHandler:
    """Accumulate (input, target) samples and fit a scikit-learn regressor.

    Inputs are flattened to 1-D feature vectors before being stored or
    queried, so every sample passed to ``add_data`` and ``predict`` must
    flatten to the same length.

    Attributes:
        model_type: The ``model_type`` string given to the constructor.
        min_samples: Minimum number of stored samples ``train`` requires.
        is_trained: ``True`` once ``train`` has fitted the model.
        data_X: List of flattened input vectors, in insertion order.
        data_Y: List of target values, aligned with ``data_X``.
        model: The underlying scikit-learn regressor.
    """

    def __init__(self, model_type='mlp', min_samples=20):
        """Create an empty handler around the requested regressor.

        Args:
            model_type: ``'mlp'`` for an ``MLPRegressor`` or ``'rf'`` for a
                ``RandomForestRegressor``.
            min_samples: Number of samples that must be stored before
                ``train`` actually fits the model. Defaults to 20.

        Raises:
            ValueError: If ``min_samples`` is smaller than 1 or
                ``model_type`` is not one of the supported names.
        """
        # Fitting on zero samples can never succeed, so reject it up front.
        if min_samples < 1:
            raise ValueError("min_samples must be at least 1")

        self.model_type = model_type
        self.min_samples = min_samples
        self.is_trained = False
        self.data_X = []
        self.data_Y = []

        # Model choice
        if model_type == 'mlp':
            self.model = MLPRegressor(
                hidden_layer_sizes=(100, 50), max_iter=500, random_state=42
            )
        elif model_type == 'rf':
            # RandomForest is generally more robust "out of the box"
            self.model = RandomForestRegressor(n_estimators=100, random_state=42)
        else:
            raise ValueError("Model type must be 'mlp' or 'rf'")

    @staticmethod
    def _flatten(x_matrix):
        """Return *x_matrix* as a fresh 1-D NumPy array."""
        # One shared helper keeps training and query features laid out
        # identically; np.array copies, so later caller-side mutation of
        # x_matrix does not alter stored samples.
        return np.array(x_matrix).flatten()

    def add_data(self, x_matrix, f2_value):
        """Store one training sample.

        Args:
            x_matrix: Array-like input; it is flattened to a 1-D vector.
            f2_value: Target value recorded for this input (presumably the
                caller's "f2" objective value; confirm against the caller).
        """
        self.data_X.append(self._flatten(x_matrix))
        self.data_Y.append(f2_value)

    def train(self):
        """Fit the model on every stored sample.

        Does nothing (and leaves ``is_trained`` unchanged) while fewer than
        ``min_samples`` samples have been collected.
        """
        if len(self.data_X) < self.min_samples:
            # Too little data to train on yet.
            return

        features = np.array(self.data_X)
        targets = np.array(self.data_Y)
        self.model.fit(features, targets)
        self.is_trained = True

    def predict(self, x_matrix):
        """Predict the target for a single input.

        Args:
            x_matrix: Array-like input; it is flattened the same way as in
                ``add_data``.

        Returns:
            The first element of the model's prediction for this input, or
            ``None`` if the model has not been trained yet.
        """
        if not self.is_trained:
            return None

        # scikit-learn estimators expect a 2-D (n_samples, n_features) array.
        single_row = self._flatten(x_matrix).reshape(1, -1)
        return self.model.predict(single_row)[0]
|