#!/usr/bin/perl
# The Missing Textutils, Ondrej Bojar, obo@cuni.cz
# http://www.cuni.cz/~obo/textutils
#
# 'nfold' runs the supplied command n times on the given data, in the manner of n-fold cross-validation.
#
# All input lines are loaded, shuffled and split n times into test and
# training data. (The test data is 1/n-th of the lines; all the rest is
# used for training.) The command is launched n times, with %test replaced
# by a temporary file containing the test data and %train replaced by
# a temporary file containing the training data.
#
# $Id: nfold,v 1.8 2006/01/28 17:08:34 bojar Exp $
#

use strict;
use Getopt::Long;
use File::Temp qw/ tmpnam tempfile /;
use List::Util qw/ shuffle /;

# Command-line options (see the usage text below for their meaning).
my $folds = 10;     # --n: number of chunks (folds) the data is split into
my $limit = 0;      # use only this number of input lines (selected at random); 0 = all
my $pivot = undef;  # --pivot: 1-based column whose value decides the chunk
my $maxfolds = 0;   # prepare N folds but evaluate only first maxfolds; 0 = all
my $help = 0;
my $testsize = 0;   # fixed testsize off by default
my $options_ok = GetOptions("n=i" => \$folds,
  "limit=i" => \$limit,
  "testsize=i" => \$testsize,
  "maxfolds=i" => \$maxfolds,
  "pivot=i" => \$pivot,
  "help" => \$help
);
# Users count columns from 1, Perl arrays from 0.
$pivot-- if defined $pivot;
my $cmd = shift;
# GetOptions returns false on unknown/malformed options; previously that was
# only a warning and the run continued with whatever got parsed.
if ($help || !$options_ok || !$cmd) {
  print STDERR "usage: nfold 'command \%train \%test'
  --n=10      ... number of folds to perform
  --pivot=i   ... use the i-th column of input to split dataset instead of 
                  random splitting to N folds
  --limit=N   ... use only N (random) lines of input for the cross-validation
  --maxfolds=i... prepare N folds but evaluate only using first i of them
  --testsize=N... sets the number of chunks so that each contains approx. N elements
";
  exit 1;
}
# Reject values that would otherwise crash much later (a zero fold count makes
# the round-robin '%' die with "Illegal modulus zero") or silently misbehave.
die "nfold: --n must be at least 1\n" if $folds < 1;
die "nfold: --testsize must not be negative\n" if $testsize < 0;
die "nfold: --limit must not be negative\n" if $limit < 0;
die "nfold: --maxfolds must not be negative\n" if $maxfolds < 0;

# Slurp every input line (from the files named on the command line, or STDIN).
my @lines = <>;

if ($testsize) {
  # A fixed test size means the data is divided so that each chunk holds
  # approx. $testsize lines; unless --maxfolds was given, only as many
  # chunks as the original --n are actually evaluated.
  $maxfolds = $folds if $maxfolds == 0;

  # Only the lines that will actually be distributed count: --limit caps them.
  my $usable = scalar @lines;
  $usable = $limit if $limit > 0 && $limit < $usable;

  # Whole chunks only, and never fewer than one: with fewer lines than
  # --testsize the plain quotient was 0 and the later '% $folds' died.
  $folds = int($usable / $testsize);
  $folds = 1 if $folds < 1;
  print STDERR "Splitting into $folds chunks, using at most $maxfolds chunks.\n";
}

# Distribute the input lines into chunks.
#
# The lines are shuffled first; with --limit only the first $limit lines of
# the shuffled order (a random sample) are kept. Without --pivot the lines are
# then dealt round-robin into $folds chunks. With --pivot, every distinct value
# of the pivot column gets its own chunk, numbered in order of first
# appearance, so all lines sharing a value end up in the same chunk.
my @chunks = ();
my %chunk_for_pivot;
my $chunks_allocated = 0;

# shuffle() is O(n); picking random elements with splice() was O(n^2).
@lines = shuffle(@lines);
splice(@lines, $limit) if $limit > 0 && $limit < @lines;

my $ch = 0;  # round-robin position (unused in pivot mode)
foreach my $e (@lines) {
  my $target;
  if (defined $pivot) {
    # Strip the line terminator before splitting; otherwise the last column's
    # value carries the "\n", which separates equal values into different
    # chunks and puts a newline into the printed chunk label.
    chomp(my $record = $e);
    my $val = (split /\t/, $record, -1)[$pivot];
    $val = '' unless defined $val;  # line lacks the pivot column
    $chunk_for_pivot{$val} = $chunks_allocated++
      unless exists $chunk_for_pivot{$val};
    $target = $chunk_for_pivot{$val};
  } else {
    $target = $ch;
    $ch = ($ch + 1) % $folds;
  }
  push @{$chunks[$target]}, $e;
}

# Label printed in front of every output line of a fold: the pivot value
# itself when splitting by --pivot, otherwise the 1-based chunk number.
my %chunk_label;
if (!defined $pivot) {
  $chunk_label{$_} = $_ + 1 for 0 .. $#chunks;
}
else {
  # Invert value => chunk index; indices are unique, so nothing collides.
  while (my ($value, $index) = each %chunk_for_pivot) {
    $chunk_label{$index} = $value;
  }
}

# Run the command once per chunk: chunk $ch is the test set and all other
# chunks together form the training set. Every line of the command's combined
# stdout/stderr is echoed prefixed by the chunk label and a tab. Stops after
# --maxfolds folds when that option is non-zero.
my $donefolds = 0;
foreach my $ch (0..$#chunks) {
  # tempfile() creates and opens the file atomically (tmpnam-style naming is
  # racy); UNLINK => 1 also removes the files at exit if we die mid-loop.
  my ($testfh, $testfn) = tempfile(UNLINK => 1);
  my ($trainfh, $trainfn) = tempfile(UNLINK => 1);

  print $testfh join("", @{$chunks[$ch]});
  # Buffered write errors (e.g. disk full) only surface at close.
  close $testfh or die "nfold: can't write test data to '$testfn': $!\n";

  foreach my $otherch (0..$#chunks) {
    next if $otherch == $ch;
    print $trainfh join("", @{$chunks[$otherch]});
  }
  close $trainfh or die "nfold: can't write training data to '$trainfn': $!\n";

  my $usecmd = $cmd;
  $usecmd =~ s/\%test/$testfn/g;
  $usecmd =~ s/\%train/$trainfn/g;

  # The command is user-supplied shell text, so it deliberately goes through
  # the shell; 2>&1 merges its stderr into the captured output.
  open my $cmdfh, '-|', "$usecmd 2>&1" or die "Can't launch '$usecmd': $!";
  while (my $out = <$cmdfh>) {
    print "$chunk_label{$ch}\t$out";
  }
  # close() on a pipe waits for the child; false means it failed ($? set).
  close $cmdfh
    or warn "nfold: command '$usecmd' failed (wait status $?)\n";
  unlink($testfn, $trainfn);

  $donefolds++;
  last if $maxfolds && $donefolds >= $maxfolds;
}
