using DataFrames
using Semirings
using TensorFSTs
using CSV
using BenchmarkTools
using OpenFst
using Glob
using IterTools
using NaNStatistics

include("../../TensorFSTs.jl/lib/OpenFstConvert.jl")

# Ordered (type of machine A, type of machine B) combinations to benchmark.
# `main` composes every machine of the first type with every machine of the
# second type. NOTE: this global shadows `Base.pairs` inside this script.
pairs = [
    ("dense", "topology"),
    ("dense", "charlm"),
    ("dense", "random"),
    ("topology", "dense"),
]

"""
    compbench(compfunc, machineA, machineB, seconds)

Benchmark the call `compfunc(machineA, machineB)` and return the `times`
field of the resulting trial.

The benchmark is run with at most 100 samples, a time budget of `seconds`,
and a single evaluation per sample. No tuning pass is performed.
"""
function compbench(compfunc, machineA, machineB, seconds)
    bench = @benchmarkable $compfunc($machineA, $machineB)
    trial = run(bench; samples=100, seconds=seconds, evals=1)
    return trial.times
end

"""
    compbench2(compfunc, machineA, machineB, seconds)

Benchmark `compfunc` on the outputs of `tensorFST2SparseArray` for the two
machines and return the `times` field of the resulting trial.

Each machine is converted with `tensorFST2SparseArray`, which yields three
values. The first value of A is sliced along its 4th dimension and the first
value of B along its 3rd dimension so that both share the same extent (the
smaller of the two). `compfunc` is then called with the two sliced arrays
followed by the remaining two values of A and of B.

The benchmark is run with at most 100 samples, a time budget of `seconds`,
and a single evaluation per sample. No tuning pass is performed.
"""
function compbench2(compfunc, machineA, machineB, seconds)
    arrA, secondA, thirdA = tensorFST2SparseArray(machineA)
    arrB, secondB, thirdB = tensorFST2SparseArray(machineB)

    # Common extent of A's 4th and B's 3rd dimension.
    shared = min(arrA.dims[4], arrB.dims[3])
    slicedA = arrA[:, :, :, 1:shared]
    slicedB = arrB[:, :, 1:shared, :]

    bench = @benchmarkable $compfunc($slicedA, $slicedB, $secondA, $thirdA, $secondB, $thirdB)
    trial = run(bench; samples=100, seconds=seconds, evals=1)
    return trial.times
end

"""
    renamepath(path)

Return a short label `"<second>-<third>"` built from the second and third
components of `path` (as produced by `splitpath`), with every occurrence of
`".fst"` removed from the third component.

Throws a `BoundsError` when `path` has fewer than three components.
"""
function renamepath(path)
    parts = splitpath(path)
    kind, leaf = parts[2], parts[3]
    return string(kind, "-", replace(leaf, ".fst" => ""))
end

# Read every `machines/*/*/fstinfo.csv` below `machinezoo_path` into a single table.
# Returns an empty `DataFrame()` when no file matches.
function _load_fstinfo(machinezoo_path)
    infofiles = glob(joinpath(machinezoo_path,"machines/*/*/fstinfo.csv"))
    isempty(infofiles) && return DataFrame()
    tables = DataFrame[DataFrame(CSV.File(path)) for path in infofiles]
    return reduce(vcat, tables)
end

# Add `<key>_min`, `_max`, `_mean`, `_std` (NaN-ignoring) and `_len` (number of
# non-NaN samples) entries to `stats` for every `key => samples` entry of `times`.
function _add_timing_stats!(stats, times)
    for (key, samples) in times
        stats[Symbol("$(key)_min")] = nanminimum(samples)
        stats[Symbol("$(key)_max")] = nanmaximum(samples)
        stats[Symbol("$(key)_mean")] = nanmean(samples)
        stats[Symbol("$(key)_std")] = nanstd(samples)
        stats[Symbol("$(key)_len")] = count(!isnan, samples)
    end
    return stats
end

"""
    main()

Benchmark FST composition over the machines listed in the machine zoo at
`../../MachineZoo.jl/` and write a summary to `composition_benchmark.csv` in
the current working directory.

Steps:

1. Make sure `machines/composition` exists inside the machine zoo.
2. Load every `machines/*/*/fstinfo.csv`; the machine type is taken from the
   second path component of the `file` column.
3. Drop machines of type `"dense"` that have more than 1000 arcs.
4. For every type pair in the global `pairs` and every combination of machines
   whose `"arc type"` columns match: compose the two machines with
   `OF.compose`, skip the combination when the result has no states, otherwise
   write the result to `machines/composition` and time `OF.compose` and
   `TF.fsmcompose` with `compbench`.
5. Write one row per benchmarked combination to the CSV, or print
   `"No results"` when nothing was benchmarked.

Prints `"No machines found"` and returns early when the zoo holds no machine
descriptions. A failing `TF` benchmark is logged as a warning and recorded as a
single `NaN` sample. Returns `nothing`.
"""
function main()
    machinezoo_path = "../../MachineZoo.jl/"

    # mkpath is a no-op for an existing directory and also creates missing parents.
    mkpath(joinpath(machinezoo_path, "machines", "composition"))

    # NOTE(review): both backends share this one-second budget; a separate,
    # larger budget for the TF run may have been intended -- confirm.
    bench_seconds = 1

    df = _load_fstinfo(machinezoo_path)
    # Check for emptiness before touching any column: with no input files the
    # table has no columns at all.
    if nrow(df) == 0
        println("No machines found")
        return nothing
    end
    df[!,"type"] = map(x -> splitpath(x)[2], df[!,"file"])

    # Exclude large dense machines (more than 1000 arcs).
    large_dense = (df[!,"# of arcs"] .> 1000) .& (df[!,"type"] .== "dense")
    df = df[.!large_dense, :]

    results = []
    num = 0
    for (typeA, typeB) in pairs
        dfx = df[df.type .== typeA, :]
        dfy = df[df.type .== typeB, :]
        # The first iterator varies fastest, which fixes the order in which the
        # running number `num` is assigned to output files.
        for (rowA, rowB) in Iterators.product(eachrow(dfx), eachrow(dfy))
            rowA["arc type"] == rowB["arc type"] || continue

            fileA, fileB = rowA.file, rowB.file
            ofstA = OF.read(joinpath(machinezoo_path, fileA))
            ofstB = OF.read(joinpath(machinezoo_path, fileB))
            println((fileA, fileB))

            ofstC = OF.compose(ofstA, ofstB)
            if OF.numstates(ofstC) == 0
                println("skipping")
                continue
            end

            num += 1
            labelA = renamepath(fileA)
            labelB = renamepath(fileB)
            fileC = joinpath("machines", "composition", "$(num)-$(labelA)-x-$(labelB).fst")
            OF.write(ofstC, joinpath(machinezoo_path, fileC))

            times = Dict{String,Vector{Float64}}()
            times["ofst"] = compbench(OF.compose, ofstA, ofstB, bench_seconds)
            times["tfst"] = try
                compbench(
                    TF.fsmcompose,
                    TF.SparseTensorFSM(ofstA),
                    TF.SparseTensorFSM(ofstB),
                    bench_seconds,
                )
            catch err
                # Never swallow a user interrupt; any other failure is recorded
                # as a NaN sample so the remaining combinations still run.
                err isa InterruptException && rethrow()
                @warn "tf failed" exception = (err, catch_backtrace())
                [NaN]
            end

            # CSV column order follows the iteration order of this Dict.
            stats = Dict{Symbol,Any}()
            stats[:fileA] = fileA
            stats[:fileB] = fileB
            stats[:fileC] = fileC
            _add_timing_stats!(stats, times)
            push!(results, NamedTuple(stats))
        end
    end

    if isempty(results)
        println("No results")
    else
        CSV.write("composition_benchmark.csv", DataFrame(results))
    end
    return nothing
end

# Script entry point: run the benchmark as soon as this file is executed or included.
main()