import { DeploymentStatus, Dictionary, IAccessOptions, ModelMode, ProblemType, ProblemTypeKeys, TClassMapping } from 'common';
import dayjs from 'dayjs';
import { IModel, IModelTrainSpecs } from './IModel';
import { TOutputSchema } from './IModelOutputSchema';

/**
 * Model-card payload in its snake_case wire shape.
 *
 * `getModelCardData` converts it into the camelCase `IModel`. Fields are grouped
 * by topic below purely for readability; the grouping has no effect on the type.
 */
export interface IModelCardResponsePayload {
    // --- Identity and classification ---
    model_id: string;
    model_name: string;
    model_version: number;
    vendor_id: string;
    category: string;
    /** Not copied by `getModelCardData`. */
    problem: string;
    /** Key into the `ProblemType` enum; `getModelCardData` resolves it via `ProblemType[problem_type]`. */
    problem_type: ProblemTypeKeys;
    modality: string;
    anatomy: string;
    /** Replaced with `[]` by `getModelCardData` when absent. */
    tags?: string[];
    mode: ModelMode;

    // --- Artifact and runtime ---
    model_architecture: string;
    model_framework: string;
    framework_version: string;
    reqs: string;
    preprocess: string;
    postprocess: string;
    loader: string;
    model_url: string;
    model_size: number;
    class_mappings: TClassMapping;
    /** Replaced with `{}` by `getModelCardData` when null/undefined. */
    output_mapping?: Dictionary<TOutputSchema>;
    /** Replaced with `null` by `getModelCardData` when null/undefined. */
    file_tree?: Dictionary<any>;

    // --- Dates ---
    // created_timestamp, last_updated_timestamp, last_validation_date and
    // last_deployed_date are parsed with dayjs.unix(), i.e. as seconds since the epoch.
    /** Not copied by `getModelCardData`. */
    model_last_updated: number;
    created_timestamp: number;
    last_updated_timestamp: number;
    last_validation_date: number;
    last_deployed_date: number;

    // --- Training ---
    model_train_specs: IModelTrainSpecsPayload;
    training_dataset_stats?: any;
    compatible_dataset_ids: string[];
    train_job_id?: string;
    model_performance?: Dictionary<number>;

    // --- Descriptive content ---
    model_summary?: string;
    model_detail?: string;
    model_picture?: string;
    model_access_options?: IAccessOptions;
    view_count?: number;

    // --- Deployment and flags ---
    deployment_status?: DeploymentStatus;
    deployment_id?: string;
    deployment?: boolean;
    archived?: boolean;
    dockerhub?: boolean;
    is_xai?: boolean;
    interactive?: boolean;
    /** Name must match the payload key as-is (presumably a misspelling of "interaction"). */
    intraction?: string;
    experimental?: boolean;
    linked_model?: string;
}

/**
 * Training details nested under `model_train_specs` in the model-card payload.
 *
 * `getModelTrainSpecs` maps every field except `training_dataset_stats` to its
 * camelCase counterpart.
 */
export interface IModelTrainSpecsPayload {
    // Run configuration
    batch_size: number;
    num_epochs: number;
    num_iterations: number;
    learning_rate: number;
    optimizer: string;
    loss_function: string;

    // Data
    dataset_size: number;
    /** Not copied by `getModelTrainSpecs`. */
    training_dataset_stats: Dictionary<any>[];
    class_weights: Record<number, number>;

    // Results
    validation_metrics: Dictionary<number>;
    loss_distribution: Dictionary<number>;
}

/**
 * Converts a snake_case model-card payload into the camelCase `IModel`.
 *
 * Notable conversions:
 * - Date fields are rendered by `getDateFromTimeStamp` as `MM.DD.YYYY` strings
 *   (empty string when the timestamp is falsy).
 * - `problem_type` is resolved through the `ProblemType` enum.
 * - Defaults: `tags` → `[]`, `output_mapping` → `{}`, `file_tree` → `null`,
 *   `archived` / `dockerhub` / `interactive` → `false`.
 * - `problem` and `model_last_updated` are not copied.
 *
 * @param data - Payload to convert.
 * @returns The mapped model. When `data` is falsy, an empty object asserted as
 *          `IModel` is returned, so callers must not assume any field is present.
 */
export function getModelCardData(data: IModelCardResponsePayload): IModel {
    // Narrow assertion instead of `as any` so the rest of the function stays type-checked.
    if (!data) return {} as IModel;
    return {
        modelId: data.model_id,
        architecture: data.model_architecture,
        framework: data.model_framework,
        frameworkVersion: data.framework_version,
        reqs: data.reqs,
        preprocess: data.preprocess,
        postprocess: data.postprocess,
        version: data.model_version,
        url: data.model_url,
        size: data.model_size,
        classMappings: data.class_mappings,
        category: data.category,
        tags: data.tags || [],
        vendorId: data.vendor_id,
        lastUpdateDate: getDateFromTimeStamp(data.last_updated_timestamp),
        createdDate: getDateFromTimeStamp(data.created_timestamp),
        loader: data.loader,
        modality: data.modality,
        modelName: data.model_name,
        anatomy: data.anatomy,
        modelTrainSpecs: getModelTrainSpecs(data.model_train_specs),
        problemType: ProblemType[data.problem_type],
        compatibleDatasets: data.compatible_dataset_ids,
        lastDeployedDate: getDateFromTimeStamp(data.last_deployed_date),
        lastValidationDate: getDateFromTimeStamp(data.last_validation_date),
        modelSummary: data.model_summary,
        modelDetail: data.model_detail,
        modelPerformance: data.model_performance,
        trainingDatasetStats: data.training_dataset_stats,
        modelAccessOptions: data.model_access_options,
        modelPicture: data.model_picture,
        deploymentId: data.deployment_id,
        deploymentStatus: data.deployment_status,
        archived: data.archived || false,
        dockerhub: data.dockerhub || false,
        viewCount: data.view_count,
        // NOTE(review): `|| null` maps `false` to `null` (unlike the `|| false` defaults
        // above), so isXai is only ever truthy or null. Kept because callers may test for
        // null — confirm whether `?? null` was intended.
        isXai: data.is_xai || null,
        outputMapping: data.output_mapping ?? {},
        file_tree: data.file_tree ?? null,
        mode: data.mode,
        interactive: data.interactive || false,
        intraction: data.intraction,
        experimental: data.experimental,
        deployment: data.deployment,
        trainJobId: data.train_job_id,
        linkedModel: data.linked_model,
    };
}

/**
 * Formats a Unix timestamp (seconds) as `MM.DD.YYYY`.
 *
 * @param date - Seconds since the epoch, as accepted by `dayjs.unix()`.
 * @returns The formatted date, or an empty string for any falsy input (0, NaN, missing).
 */
function getDateFromTimeStamp(date: number): string {
    if (!date) {
        return '';
    }
    return dayjs.unix(date).format('MM.DD.YYYY');
}

/**
 * Converts the snake_case training-spec payload into camelCase `IModelTrainSpecs`.
 *
 * `training_dataset_stats` is not carried over. If `data` is null/undefined at
 * runtime, every field of the result is `undefined` (the keys are still present).
 *
 * @param data - Training specs taken from `model_train_specs` of a model-card payload.
 * @returns The mapped training specs.
 */
export function getModelTrainSpecs(data: IModelTrainSpecsPayload): IModelTrainSpecs {
    // Null-safe field accessor; yields undefined when the payload itself is missing.
    const pick = <K extends keyof IModelTrainSpecsPayload>(key: K): IModelTrainSpecsPayload[K] => data?.[key];

    return {
        validationMetrics: pick('validation_metrics'),
        lossDistribution: pick('loss_distribution'),
        classWeights: pick('class_weights'),
        batchSize: pick('batch_size'),
        numEpochs: pick('num_epochs'),
        numIterations: pick('num_iterations'),
        learningRate: pick('learning_rate'),
        optimizer: pick('optimizer'),
        lossFunction: pick('loss_function'),
        datasetSize: pick('dataset_size'),
    };
}
