#include "vocab.hpp"
#include "pattern.hpp"
#include "encoder.hpp"
#include "bagb_encoder.hpp"
#include "weights.hpp"
#include "trainer.hpp"
#include "base_trainer.hpp"

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

int main(int argc,char*argv[])
{
  vocab v;
  v.load(argv[1]);
  encoder*enc=new bagb_encoder;
  enc->set_elements(v);
  pattern_pool pp=enc->encode_vocab_patterns(v);
  weights w(argv[2]);
  
  std::ifstream is(argv[3]);
  std::vector<std::string> conds,targets;
  std::vector<std::vector<std::string> > primes;
  std::string ln,it;
  std::getline(is,ln);
  std::stringstream ss(ln);
  ss>>it;
  while(ss>>it)
  {
    conds.push_back(it);
  }
  while(std::getline(is,ln))
  {
    std::stringstream ss2(ln);
    ss2>>it;
    targets.push_back(it);
    primes.push_back(std::vector<std::string>());
    while(ss2>>it)
    {
      primes.back().push_back(it);
    }
  }

  std::vector<std::vector<pattern> > prime_patterns;
  for(auto i=primes.begin();i!=primes.end();++i)
  {
    prime_patterns.push_back(enc->encode_strings(*i));
  }

  std::vector<double> sums(conds.size()),ns(conds.size());

  std::ofstream of(argv[4]);
  for(auto i=conds.begin();i!=conds.end();++i)
  {
    of<<"\t"<<*i;
  }
  of<<"\n";
  auto j=prime_patterns.begin();
  for(auto i=targets.begin();i!=targets.end();++i,++j)
  {
    of<<*i;
    auto u=sums.begin(),k=ns.begin();
    for(auto l=j->begin();l!=j->end();++l,++u,++k)
    {
      double z=w.selective_predict(*l,v.which(*i));
      if(!std::isnan(z)){++(*k);*u+=z;}
      of<<"\t"<<z;
    }
    of<<"\n";
  }

  for(int i=0;i<conds.size()+1;++i)
    of<<"=============";
  of<<"\n";

  for(auto i=conds.begin();i!=conds.end();++i)
  {
    of<<"\t"<<(*i);
  }
  of<<"\n";
  for(auto l=sums.begin(),k=ns.begin();l!=sums.end();++l,++k)
  {
    of<<"\t"<<(*l)/(*k);
  }
  of<<"\n";
  
  delete enc;
}
