Go to the documentation of this file.00001
00002 from sys import argv, exit, stdout, stderr
00003 from random import randint
00004
# Command-line settings; process_options() fills these in from the argument vector.
method = 0             # selection method: 0 = stratified, 1 = random
n = 0                  # requested number of instances in the subset
dataset_filename = ""  # input data file
subset_filename = ""   # output file for the subset; "" means standard output
rest_filename = ""     # output file for the unselected lines; "" means none
00010
def exit_with_help():
    """Print the command-line usage text on stdout and terminate with status 1."""
    usage_lines = (
        "Usage: %s [options] dataset number [output1] [output2]",
        "",
        "This script selects a subset of the given dataset.",
        "",
        "options:",
        "-s method : method of selection (default 0)",
        "0 -- stratified selection (classification only)",
        "1 -- random selection",
        "",
        "output1 : the subset (optional)",
        "output2 : rest of the data (optional)",
        "If output1 is omitted, the subset will be printed on the screen.",
    )
    # The program name is substituted for the single %s placeholder.
    usage = "\n".join(usage_lines) % argv[0]
    print(usage)
    exit(1)
00026
def process_options(args=None):
    """Parse the command line into the module-level settings.

    Expected syntax: ``[-s method] dataset number [output1] [output2]``.
    Always sets the globals ``dataset_filename`` and ``n``; sets ``method``,
    ``subset_filename`` and ``rest_filename`` only when the corresponding
    argument is present, so they otherwise keep their current values.

    A malformed command line calls ``exit_with_help()``. This covers too
    few arguments, an unknown option, a missing or non-integer ``-s``
    value, a method other than 0/1, and a missing, non-integer or negative
    ``number``. Where useful, a message is printed first.

    args: argument vector to parse, program name at index 0; defaults to
        ``sys.argv``.
    """
    global method, n
    global dataset_filename, subset_filename, rest_filename

    if args is None:
        args = argv
    argc = len(args)
    if argc < 3:
        exit_with_help()

    i = 1
    # Leading arguments beginning with "-" are options; "-s" is the only one.
    # startswith() (unlike args[i][0]) is safe on an empty argument.
    while i < argc and args[i].startswith("-"):
        if args[i] != "-s":
            print("Unknown option %s" % args[i])
            exit_with_help()
        i += 1
        if i >= argc:
            # "-s" was the last argument, so its value is missing.
            exit_with_help()
        try:
            method = int(args[i])
        except ValueError:
            print("Unknown selection method %s" % args[i])
            exit_with_help()
        if method < 0 or method > 1:
            print("Unknown selection method %d" % (method))
            exit_with_help()
        i += 1

    # The two mandatory positionals (dataset, number) must follow the options.
    if i + 1 >= argc:
        exit_with_help()
    dataset_filename = args[i]
    try:
        n = int(args[i + 1])
    except ValueError:
        print("Invalid number of instances %s" % args[i + 1])
        exit_with_help()
    if n < 0:
        print("Invalid number of instances %d" % n)
        exit_with_help()

    if i + 2 < argc:
        subset_filename = args[i + 2]
    if i + 3 < argc:
        rest_filename = args[i + 3]
00053
def _read_labels(filename):
    """Return the label of every line of *filename*, in file order.

    The label is the first whitespace-separated field, parsed as float.
    A line with no fields makes ``line.split()[0]`` raise IndexError.
    """
    with open(filename, 'r') as f:
        return [float(line.split()[0]) for line in f]


def _stratified_select(labels, n):
    """Select about n instances, taking from each class in proportion to its size.

    labels: per-instance labels in file order.
    n: requested subset size.
    Returns ``(selected, warning)``. ``selected[i]`` is 1 when instance i
    is chosen, else 0. ``warning`` counts classes whose proportional quota
    rounded down to 0 and was raised to 1. An empty ``labels`` yields
    ``([], 0)``.
    """
    l = len(labels)
    selected = [0] * l
    warning = 0
    # Visit instances grouped by label. sorted() is stable, so equal labels
    # keep file order; this keeps the same randint call sequence as sorting
    # the instances themselves.
    order = sorted(range(l), key=labels.__getitem__)
    begin = 0
    while begin < l:
        end = begin + 1
        while end < l and labels[order[end]] == labels[order[begin]]:
            end += 1
        nr_class = end - begin
        # Quotas telescope: summed over all classes they total exactly n.
        k = end * n // l - begin * n // l
        if k == 0:
            k = 1
            warning += 1
        # Sequential selection sampling. Each member is kept with probability
        # (still needed) / (still available), which picks exactly k members
        # when 0 <= k <= nr_class.
        for j in range(nr_class):
            if randint(0, nr_class - j - 1) < k:
                selected[order[begin + j]] = 1
                k -= 1
        begin = end
    return selected, warning


def _random_select(l, n):
    """Select n of l instances uniformly at random (all of them when n >= l).

    Returns a list of l flags, 1 for chosen instances and 0 otherwise.
    """
    selected = [0] * l
    k = n
    for i in range(l):
        if randint(0, l - i - 1) < k:
            selected[i] = 1
            k -= 1
    return selected


def _write_split(src, selected, subset_path, rest_path):
    """Copy the lines of *src* to the subset and rest outputs.

    Line i goes to the subset output when ``selected[i]`` is truthy.
    Otherwise it goes to the rest output, if there is one.
    subset_path: file for selected lines, or "" to write them to stdout.
    rest_path: file for unselected lines, or "" to discard them.
    Every file opened here is closed on exit, even on error; stdout is
    left open.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        source = stack.enter_context(open(src, 'r'))
        if subset_path != "":
            subset_out = stack.enter_context(open(subset_path, 'w'))
        else:
            subset_out = stdout
        if rest_path != "":
            rest_out = stack.enter_context(open(rest_path, 'w'))
        else:
            rest_out = None
        for line, flag in zip(source, selected):
            if flag:
                subset_out.write(line)
            elif rest_out is not None:
                rest_out.write(line)


def main():
    """Run the script end to end.

    1. Parse the command line.
    2. Read the label of every instance.
    3. Select the subset: stratified for method 0, uniformly random
       otherwise.
    4. Write selected lines to the subset output, and the other lines to
       the rest output when one was requested.
    5. Warn on stderr if any class's proportional quota had to be raised
       from 0 to 1.
    """
    process_options()

    labels = _read_labels(dataset_filename)
    warning = 0
    if method == 0:
        selected, warning = _stratified_select(labels, n)
    else:
        selected = _random_select(len(labels), n)

    _write_split(dataset_filename, selected, subset_filename, rest_filename)

    if warning > 0:
        stderr.write("""\
Warning:
1. You may have regression data. Please use -s 1.
2. Classification data unbalanced or too small. We select at least 1 per class.
The subset thus contains %d instances.
""" % (n+warning))
00145
# Run only when executed as a script, so importing the module has no side effects.
if __name__ == "__main__":
    main()