/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ratelimit;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.breaker.CircuitBreakerService;
import org.opensearch.timeseries.caching.TimeSeriesCache;
import org.opensearch.timeseries.constant.CommonMessages;
import org.opensearch.timeseries.indices.IndexManagement;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.IntermediateResult;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IndexableResult;
import org.opensearch.timeseries.model.TaskState;
import org.opensearch.timeseries.model.TaskType;
import org.opensearch.timeseries.model.TimeSeriesTask;
import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker;
import org.opensearch.timeseries.ratelimit.FeatureRequest;
import org.opensearch.timeseries.ratelimit.RequestPriority;
import org.opensearch.timeseries.ratelimit.SaveResultStrategy;
import org.opensearch.timeseries.ratelimit.SingleRequestWorker;
import org.opensearch.timeseries.task.TaskCacheManager;
import org.opensearch.timeseries.task.TaskManager;
import org.opensearch.timeseries.util.ExceptionUtil;

/**
 * Worker that runs model cold start (initial training) for each {@link FeatureRequest} handed to
 * {@link #executeRequest}.
 *
 * <p>For each request it:
 * <ol>
 *   <li>builds an empty {@link ModelState} via {@link #createEmptyState};</li>
 *   <li>asks the cold starter to train it;</li>
 *   <li>once the config has been loaded, handles the outcome.</li>
 * </ol>
 *
 * <p>If a model was produced:
 * <ul>
 *   <li>any returned training results are saved;</li>
 *   <li>the request's current feature is scored with the new model and that result is saved;</li>
 *   <li>for requests without a task id, the model is offered to the cache and a checkpoint write is requested.</li>
 * </ul>
 *
 * <p>If no model was produced:
 * <ul>
 *   <li>the request's task (if any) is marked inactive with a not-enough-data error;</li>
 *   <li>otherwise, the model state is checkpointed when it holds samples.</li>
 * </ul>
 */
public abstract class ColdStartWorker<
    RCFModelType extends ThresholdedRandomCutForest,
    IndexType extends Enum<IndexType>,
    IndexManagementType extends IndexManagement<IndexType>,
    CheckpointDaoType extends CheckpointDao<RCFModelType, IndexType, IndexManagementType>,
    CheckpointWriteWorkerType extends CheckpointWriteWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType>,
    ColdStarterType extends ModelColdStart<RCFModelType, IndexType, IndexManagementType, IndexableResultType>,
    CacheType extends TimeSeriesCache<RCFModelType>,
    IndexableResultType extends IndexableResult,
    IntermediateResultType extends IntermediateResult<IndexableResultType>,
    ModelManagerType extends ModelManager<RCFModelType, IndexableResultType, IntermediateResultType, IndexType, IndexManagementType, CheckpointDaoType, ColdStarterType>,
    SaveResultStrategyType extends SaveResultStrategy<IndexableResultType, IntermediateResultType>,
    TaskCacheManagerType extends TaskCacheManager,
    TaskTypeEnum extends TaskType,
    TaskClass extends TimeSeriesTask,
    TaskManagerType extends TaskManager<TaskCacheManagerType, TaskTypeEnum, TaskClass, IndexType, IndexManagementType>>
    extends SingleRequestWorker<FeatureRequest> {

    private static final Logger LOG = LogManager.getLogger(ColdStartWorker.class);

    protected final ColdStarterType coldStarter;
    protected final CacheType cacheProvider;
    private final ModelManagerType modelManager;
    private final SaveResultStrategyType resultSaver;
    private final TaskManagerType taskManager;
    protected final CheckpointWriteWorkerType checkpointWriteWorker;

    /**
     * Creates the worker.
     *
     * <p>All queueing and throttling arguments are forwarded unchanged to {@link SingleRequestWorker}. The remaining
     * arguments are the collaborators used to train, score, save, host and checkpoint models.
     */
    public ColdStartWorker(
        String workerName,
        long heapSizeInBytes,
        int singleRequestSizeInBytes,
        Setting<Float> maxHeapPercentForQueueSetting,
        ClusterService clusterService,
        Random random,
        CircuitBreakerService adCircuitBreakerService,
        ThreadPool threadPool,
        String threadPoolName,
        Settings settings,
        float maxQueuedTaskRatio,
        Clock clock,
        float mediumSegmentPruneRatio,
        float lowSegmentPruneRatio,
        int maintenanceFreqConstant,
        Setting<Integer> concurrency,
        Duration executionTtl,
        ColdStarterType coldStarter,
        Duration stateTtl,
        NodeStateManager nodeStateManager,
        CacheType cacheProvider,
        AnalysisType context,
        ModelManagerType modelManager,
        SaveResultStrategyType resultSaver,
        TaskManagerType taskManager,
        CheckpointWriteWorkerType checkpointWriteWorker
    ) {
        super(
            workerName,
            heapSizeInBytes,
            singleRequestSizeInBytes,
            maxHeapPercentForQueueSetting,
            clusterService,
            random,
            adCircuitBreakerService,
            threadPool,
            threadPoolName,
            settings,
            maxQueuedTaskRatio,
            clock,
            mediumSegmentPruneRatio,
            lowSegmentPruneRatio,
            maintenanceFreqConstant,
            concurrency,
            executionTtl,
            stateTtl,
            nodeStateManager,
            context
        );
        this.coldStarter = coldStarter;
        this.cacheProvider = cacheProvider;
        this.modelManager = modelManager;
        this.resultSaver = resultSaver;
        this.taskManager = taskManager;
        this.checkpointWriteWorker = checkpointWriteWorker;
    }

    /**
     * Trains a model for {@code coldStartRequest} and persists/hosts the outcome.
     *
     * <p>{@code listener.onFailure} is called when:
     * <ul>
     *   <li>the request has no model id;</li>
     *   <li>training fails (the exception is also recorded on the node state manager, and an overloaded cluster
     *       triggers a cool-down);</li>
     *   <li>the config lookup fails;</li>
     *   <li>post-training processing throws.</li>
     * </ul>
     * Otherwise {@code listener.onResponse(null)} is called, including when the config is not found.
     *
     * @param coldStartRequest request carrying the config id, model id, optional task id and current feature
     * @param listener notified once the request has been handled
     */
    @Override
    protected void executeRequest(FeatureRequest coldStartRequest, ActionListener<Void> listener) {
        String configId = coldStartRequest.getConfigId();
        String modelId = coldStartRequest.getModelId();

        if (null == modelId) {
            String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest);
            LOG.warn(error);
            listener.onFailure(new RuntimeException(error));
            return;
        }

        ModelState<RCFModelType> modelState = createEmptyState(coldStartRequest, modelId, configId);

        ActionListener<List<IndexableResultType>> coldStartListener = ActionListener.wrap(
            trainingResults -> nodeStateManager.getConfig(
                configId,
                context,
                coldStartRequest.getTaskId() == null,
                ActionListener.wrap(
                    configOptional -> onConfigLoaded(
                        configOptional,
                        trainingResults,
                        coldStartRequest,
                        modelState,
                        modelId,
                        configId,
                        listener
                    ),
                    listener::onFailure
                )
            ),
            e -> {
                try {
                    if (ExceptionUtil.isOverloaded(e)) {
                        LOG.error("OpenSearch is overloaded");
                        setCoolDownStart();
                    }
                    nodeStateManager.setException(configId, e);
                } finally {
                    listener.onFailure(e);
                }
            }
        );

        coldStarter.trainModel(coldStartRequest, configId, modelState, coldStartListener);
    }

    /**
     * Handles the training outcome once the config lookup has completed, then notifies {@code listener} once.
     *
     * @param configOptional the loaded config, empty when the config was not found
     * @param trainingResults results returned by training; may be null
     * @param coldStartRequest the request being processed
     * @param modelState the state that training populated (or left without a model)
     * @param modelId id of the trained model
     * @param configId id of the config
     * @param listener caller's listener; receives onFailure if processing throws, onResponse(null) otherwise
     */
    private void onConfigLoaded(
        Optional<? extends Config> configOptional,
        List<IndexableResultType> trainingResults,
        FeatureRequest coldStartRequest,
        ModelState<RCFModelType> modelState,
        String modelId,
        String configId,
        ActionListener<Void> listener
    ) {
        Exception failure = null;
        try {
            if (!configOptional.isPresent()) {
                LOG.error(
                    new ParameterizedMessage(
                        "fail to load trained model [{}] to cache due to the config not being found.",
                        modelState.getModelId()
                    )
                );
            } else if (modelState.getModel().isPresent()) {
                saveAndHostTrainedModel(configOptional.get(), trainingResults, coldStartRequest, modelState, modelId);
            } else {
                handleUntrainedModel(coldStartRequest, modelState, configId);
            }
        } catch (Exception e) {
            failure = e;
        } finally {
            // Notify exactly once. Calling onResponse from a plain finally and letting the exception escape would make
            // the enclosing ActionListener.wrap also call onFailure, notifying the listener twice.
            if (failure == null) {
                listener.onResponse(null);
            } else {
                listener.onFailure(failure);
            }
        }
    }

    /**
     * Saves training results and the score of the request's current feature.
     *
     * <p>For requests without a task id, it also offers the model to the cache and requests a checkpoint write.
     */
    private void saveAndHostTrainedModel(
        Config config,
        List<IndexableResultType> trainingResults,
        FeatureRequest coldStartRequest,
        ModelState<RCFModelType> modelState,
        String modelId
    ) {
        String taskId = coldStartRequest.getTaskId();

        if (trainingResults != null) {
            for (IndexableResultType trainingResult : trainingResults) {
                resultSaver.saveResult(trainingResult, config);
            }
        }

        // The current feature covers one interval starting at the request's data start time.
        long dataStartTime = coldStartRequest.getDataStartTimeMillis();
        Sample currentSample = new Sample(
            coldStartRequest.getCurrentFeature(),
            Instant.ofEpochMilli(dataStartTime),
            Instant.ofEpochMilli(dataStartTime + config.getIntervalInMilliseconds())
        );
        IntermediateResultType result = modelManager.getResult(currentSample, modelState, modelId, config, taskId);
        resultSaver.saveResult(result, config, coldStartRequest, modelId);

        if (Strings.isEmpty(taskId)) {
            boolean hosted = cacheProvider.hostIfPossible(config, modelState);
            LOG.debug(
                hosted
                    ? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
                    : new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
            );
            checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM);
        }
    }

    /**
     * Handles a cold start that produced no model.
     *
     * <p>If the request has a task id, the task is marked inactive with a not-enough-data error. Otherwise, a
     * checkpoint write is requested when the model state holds samples.
     */
    private void handleUntrainedModel(FeatureRequest coldStartRequest, ModelState<RCFModelType> modelState, String configId) {
        String taskId = coldStartRequest.getTaskId();
        if (taskId != null) {
            Map<String, Object> updatedFields = new HashMap<>();
            updatedFields.put("state", TaskState.INACTIVE.name());
            updatedFields.put("error", CommonMessages.NOT_ENOUGH_DATA);
            taskManager.updateTask(
                taskId,
                updatedFields,
                ActionListener.wrap(
                    updateResponse -> LOG.info("Updated task {} for config {}", taskId, configId),
                    e -> LOG.error("Failed to update task: {} for config: {}", taskId, configId, e)
                )
            );
        } else if (modelState.getSamples().size() > 0) {
            checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM);
        }
    }

    /**
     * Creates the initial model state that is handed to the cold starter for training.
     *
     * @param coldStartRequest request being cold started
     * @param modelId id of the model to train
     * @param configId id of the config the model belongs to
     * @return a fresh model state for {@code modelId}
     */
    protected abstract ModelState<RCFModelType> createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId);
}

