subset.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import os, sys, math, random
00004 from collections import defaultdict
00005 
# Python 3 dropped xrange in favour of a lazy range; alias the old name so
# the rest of the script can use xrange on either major version.
if sys.version_info >= (3,):
    xrange = range
00008 
def exit_with_help(argv):
    """Print the command-line usage text and terminate with exit status 1.

    argv -- the argument list; argv[0] is shown as the program name.

    Never returns: raises SystemExit(1).
    """
    print("""\
Usage: {0} [options] dataset subset_size [output1] [output2]

This script randomly selects a subset of the dataset.

options:
-s method : method of selection (default 0)
     0 -- stratified selection (classification only)
     1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
    # Use sys.exit, not the bare exit(): the latter is injected by the site
    # module for interactive sessions and is missing under `python -S`.
    sys.exit(1)
00024 
def process_options(argv):
    """Parse the command line into the settings used by main().

    argv -- full argument list, program name first:
            [prog, options..., dataset, subset_size, [output1], [output2]]

    Returns a tuple (dataset, subset_size, method, subset_file, rest_file):
      dataset     -- path of the input file (str)
      subset_size -- number of instances to select (int)
      method      -- 0 for stratified selection, 1 for random selection
      subset_file -- file object opened for writing, or sys.stdout when
                     output1 is not given
      rest_file   -- file object opened for writing, or None when output2
                     is not given

    Prints the usage text and exits (via exit_with_help) when required
    arguments are missing or the -s value is not 0 or 1. Options other than
    -s that start with "-" are skipped without complaint. A non-integer -s
    value or subset_size raises ValueError.
    """
    argc = len(argv)
    if argc < 3:
        exit_with_help(argv)

    # default method is stratified selection
    method = 0
    subset_file = sys.stdout
    rest_file = None

    i = 1
    while i < argc:
        # startswith() is safe for an empty-string argument, unlike argv[i][0]
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i += 1
            if i >= argc:
                # "-s" was the last token: its value is missing
                exit_with_help(argv)
            method = int(argv[i])
            if method not in (0, 1):
                print("Unknown selection method {0}".format(method))
                exit_with_help(argv)
        i += 1

    # Options may have consumed the arguments counted by the argc check
    # above, so make sure both positional arguments are really present.
    if i + 1 >= argc:
        exit_with_help(argv)

    dataset = argv[i]
    subset_size = int(argv[i + 1])
    if i + 2 < argc:
        subset_file = open(argv[i + 2], 'w')
    if i + 3 < argc:
        rest_file = open(argv[i + 3], 'w')

    return dataset, subset_size, method, subset_file, rest_file
00055 
def random_selection(dataset, subset_size):
    """Choose subset_size distinct lines of the file uniformly at random.

    dataset     -- path of the input file
    subset_size -- number of lines to pick

    Returns the chosen 0-based line numbers as a sorted list.
    Raises ValueError (from random.sample) if subset_size is negative or
    larger than the number of lines in the file.
    """
    # Count lines inside a with-block so the handle is closed promptly
    # instead of being left to the garbage collector.
    with open(dataset, 'r') as f:
        num_lines = sum(1 for _ in f)
    return sorted(random.sample(range(num_lines), subset_size))
00059 
def stratified_selection(dataset, subset_size):
    """Choose lines so that each class label is represented proportionally.

    The label of a line is its first whitespace-separated token. Each class
    contributes ceil(class_size * subset_size / total_lines) lines, at least
    one, but never more than what is still left of subset_size.

    dataset     -- path of the input file
    subset_size -- number of lines to pick

    Returns the chosen 0-based line numbers as a sorted list.
    Writes an error to stderr and exits with status -1 when subset_size is
    used up before every class has received at least one line.
    """
    # Read the labels inside a with-block so the handle is closed promptly.
    with open(dataset) as f:
        labels = [line.split(None, 1)[0] for line in f]

    label_linenums = defaultdict(list)
    for linenum, label in enumerate(labels):
        label_linenums[label].append(linenum)

    total = len(labels)
    remaining = subset_size
    ret = []

    # classes with fewer data are sampled first; otherwise
    # some rare classes may not be selected
    for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
        linenums = label_linenums[label]
        quota = int(math.ceil(len(linenums) * (float(subset_size) / total)))
        # at least one instance per class, capped by what is still available
        s = min(remaining, max(1, quota))
        if s == 0:
            sys.stderr.write('''\
Error: failed to have at least one instance per class
    1. You may have regression data.
    2. Your classification data is unbalanced or too small.
Please use -s 1.
''')
            sys.exit(-1)
        remaining -= s
        # Sample the line numbers directly rather than sampling positions
        # and indexing back into the list.
        ret.extend(random.sample(linenums, s))
    return sorted(ret)
00088 
def main(argv=sys.argv):
    """Select a subset of a dataset file and write it out.

    argv -- command-line argument list (see process_options).

    The selected lines are written to the subset output in their original
    order. If a rest file was requested, every line that was not selected is
    written there, also in original order.
    """
    dataset, subset_size, method, subset_file, rest_file = process_options(argv)
    # uncomment the following line to fix the random seed
    # random.seed(0)
    selected_lines = []

    try:
        if method == 0:
            selected_lines = stratified_selection(dataset, subset_size)
        elif method == 1:
            selected_lines = random_selection(dataset, subset_size)

        # selected_lines is sorted, so one pass over the file suffices:
        # walk both in step and route each line to the proper output.
        pending = iter(selected_lines)
        next_selected = next(pending, None)
        with open(dataset, 'r') as data:
            for linenum, line in enumerate(data):
                if linenum == next_selected:
                    subset_file.write(line)
                    next_selected = next(pending, None)
                elif rest_file:
                    rest_file.write(line)
                elif next_selected is None:
                    # nothing left to select and no rest file: stop reading
                    break
    finally:
        # Never close sys.stdout: the caller may still want to print after
        # main() returns. Only flush it; close files this script opened.
        if subset_file is sys.stdout:
            subset_file.flush()
        else:
            subset_file.close()
        if rest_file:
            rest_file.close()
00117 
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main(sys.argv)
00120 


ml_classifiers
Author(s): Scott Niekum
autogenerated on Fri Jan 3 2014 11:30:23