import { zipObject, zip } from 'lodash'
import { INTERCEPT_VARIABLE, PREDICTION_COLUMN } from '../constants/constants'

/**
 * Linear predictor of the model: the sum over every coefficient of
 * `coefficient * value`, where the value of INTERCEPT_VARIABLE is fixed at 1.
 *
 * @param variablesCoefficients - coefficient per variable name (may include INTERCEPT_VARIABLE).
 * @param values - value per variable name; any entry for INTERCEPT_VARIABLE is ignored.
 * @returns the weighted sum; NaN if a non-intercept variable has no entry in `values`.
 */
const dotProd = (
    variablesCoefficients: { [variable: string]: number },
    values: { [variable: string]: number }
): number => {
    return Object.entries(variablesCoefficients)
        .map(([variable, coefficient]) =>
            // The intercept's implicit value is 1, so it contributes exactly its coefficient.
            // A missing value yields NaN (as plain `undefined * n` would) so bad input is visible.
            coefficient * (variable === INTERCEPT_VARIABLE ? 1 : (values[variable] ?? NaN)))
        .reduce((sum, term) => sum + term, 0)
}

/**
 * Probability predicted by the logistic-regression model.
 *
 * The linear predictor returned by `dotProd` is passed through the logistic
 * (sigmoid) function 1 / (1 + e^(-z)), giving a value between 0 and 1
 * (NaN if the linear predictor is NaN).
 *
 * @param variablesCoefficients - model coefficient per variable name.
 * @param values - input value per variable name.
 * @returns the predicted probability.
 */
export function doPrediction (
    variablesCoefficients: { [varaible: string]: number },
    values: { [variable: string]: number }
): number {
    const z = dotProd(variablesCoefficients, values)
    const expNegZ = Math.exp(-z)
    return 1 / (1 + expNegZ)
}

/**
 * Triggers a browser download of `data` as a CSV file named `predictions.csv`.
 *
 * A temporary anchor pointing at an object URL for the data is attached to the
 * document, clicked, and removed again; the object URL is then released.
 *
 * @param data - full CSV text to save.
 */
function downloadPredictions (data: string): void {
    const url = window.URL.createObjectURL(new Blob([data], { type: 'text/csv' }))
    const a = window.document.createElement('a')
    a.href = url
    a.download = 'predictions.csv'

    // Append anchor to body so the click is dispatched from an element in the document.
    window.document.body.appendChild(a)
    try {
        a.click()
    } finally {
        // Remove anchor from body even if click() throws.
        window.document.body.removeChild(a)
        // An object URL keeps its Blob alive until revoked. Revoke on the next tick so the
        // download started by click() can still resolve the URL.
        window.setTimeout(() => window.URL.revokeObjectURL(url), 0)
    }
}

/**
 * Scores every data row of a CSV file with the model and downloads the result
 * as `predictions.csv`: the original columns plus a PREDICTION_COLUMN.
 *
 * The first non-blank line is the header. Cell values are matched to
 * coefficients by column name, so column order in the file does not matter.
 * Parsing is a plain comma split; quoted fields containing commas are not supported.
 *
 * @param variablesCoefficients - model coefficient per variable name. INTERCEPT_VARIABLE
 *   needs no column in the file; its value is fixed at 1.
 * @param csvFileData - raw CSV text (LF or CRLF line endings).
 * @throws Error if the file has no non-blank lines, or a coefficient's column is missing.
 */
export function bulkPredictions (variablesCoefficients: { [variable: string]: number }, csvFileData: string): void {
    const [header, ...rows] = csvFileData
        .split(/\r?\n/)
        // Drop blank and whitespace-only lines (e.g. trailing newlines).
        .filter(line => line.trim().length > 0)
        .map(line => line.split(',').map(col => col.trim()))
    if (header === undefined) {
        throw new Error('The file is empty.')
    }

    // Check that all necessary columns are there; the intercept is implicit.
    const headerSet = new Set([...header, INTERCEPT_VARIABLE])
    Object.keys(variablesCoefficients).forEach(columnName => {
        if (!headerSet.has(columnName)) {
            throw new Error(`Column ${columnName} was not found in the file.`)
        }
    })

    const predictions = rows.map(row => {
        // Number('') is 0, which would silently impute missing data; use NaN so
        // a blank input cell yields a visibly invalid prediction instead.
        const numRow = row.map(col => (col === '' ? NaN : Number(col)))
        // Key each cell by its own header name, not by coefficient order.
        const objRow = { ...zipObject(header, numRow), [INTERCEPT_VARIABLE]: 1 }
        return doPrediction(variablesCoefficients, objRow)
    })

    const newHeader = [...header, PREDICTION_COLUMN].join(',')
    const newRows = zip(rows, predictions)
        .map(([row, prediction]) => [...(row as string[]), prediction as number].join(','))
        .join('\n')
    downloadPredictions(`${newHeader}\n${newRows}`)
}
