adding the surrogate handler
This commit is contained in:
40
surrogate_handler.py
Normal file
40
surrogate_handler.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import numpy as np
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
|
||||
class SurrogateHandler:
|
||||
def __init__(self, model_type='mlp'):
|
||||
self.model_type = model_type
|
||||
self.is_trained = False
|
||||
self.data_X = []
|
||||
self.data_Y = []
|
||||
|
||||
# Model choice
|
||||
if model_type == 'mlp':
|
||||
self.model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
|
||||
elif model_type == 'rf':
|
||||
# RandomForest is generaly more robust "out of the box"
|
||||
self.model = RandomForestRegressor(n_estimators=100, random_state=42)
|
||||
else:
|
||||
raise ValueError("Model type must be 'mlp' or 'rf'")
|
||||
|
||||
def add_data(self, x_matrix, f2_value):
|
||||
# Flattening the position matrix to a 1 dimension vector
|
||||
flat_x = np.array(x_matrix).flatten()
|
||||
self.data_X.append(flat_x)
|
||||
self.data_Y.append(f2_value)
|
||||
|
||||
def train(self):
|
||||
if len(self.data_X) < 20: # No training if their is too few data
|
||||
return
|
||||
|
||||
X = np.array(self.data_X)
|
||||
y = np.array(self.data_Y)
|
||||
self.model.fit(X, y)
|
||||
self.is_trained = True
|
||||
|
||||
def predict(self, x_matrix):
|
||||
if not self.is_trained:
|
||||
return None
|
||||
flat_x = np.array(x_matrix).flatten().reshape(1, -1)
|
||||
return self.model.predict(flat_x)[0]
|
||||
Reference in New Issue
Block a user