logLevelHandler = new Fn() {
17 | public Object apply(String input) {
18 | try {
19 | Level logLevel = Level.toLevel(input);
20 | return logLevel;
21 | } catch (Exception e) {
22 | e.printStackTrace();
23 | System.exit(0);
24 | }
25 | return null;
26 | }};
27 |
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/exec/package.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | To use the Execution framework, in the main
of your entry point, the
4 | first line should be
5 |
6 |
7 | edu.umass.nlp.exec.Execution.init(pathToConfigFile);
8 |
9 | The pathToConfigFile
should be a string to a
10 | yaml file that stores global options. Global options are used to populate
11 | objects with public mutable fields using reflection. The purpose of this
12 | is to provide easy (but not secure or robust) option management.
13 |
14 | The Execution framework does many things; here's a summary:
15 |
16 | Global Option Configuration
17 |
18 | Options are grouped together in a hierarchical fashion. For instance, the
19 | main Execution option would be given in yaml by
20 |
21 |
22 |
23 | exec:
24 | execPoolDir: execs
25 | loggerPattern: %-5p [%c]: %m%n
26 |
27 |
28 |
29 | So in your code when you call:
30 |
31 |
32 | Execution.Opts execOpts = Execution.fillOptions ("exec", new Execution.Opts());
33 |
34 |
35 |
36 | We reflectively look up options under the exec
part of the configuration
37 | file and fill in the fields of the passed-in object. You can fill in different instances
38 | of the same options object by using different group names (e.g. exec1, exec2, ...
).
39 |
40 | You can also do hierarchical option filling.
41 |
42 | Store Execution Log and Options in a Directory
43 |
44 | Another feature of the Execution framework is that the log of every run goes to a directory
45 | specified by the exec.execDir
option. Typically, you shouldn't specify the
46 | directory and instead use the option exec.execPoolDir
which will automatically
47 | make a new directory for each execution run by adding 0.exec,1.exec,2.exec,...
48 | as needed. The directory will store everything sent to the logger as well as a copy of configuration
49 | needed to re-run the experiment (modulo code changes obviously).
50 |
51 | If you want to store other output in the execution directory, you have access to it in
52 | Execution.getExecutionDirectry
.
53 |
54 | You can add option processing behavior using
55 | Execution.addOptionHandler
. See
56 | StandardOptionHandlers
for example
57 | option handlers.
58 |
59 | Apache Logger Configuration
60 |
61 | The Execution framework also configures the log4j logger (the pattern
62 | for the logger prefix is configurable via the exec.loggerPattern
option
63 | in your global config file).
64 |
65 | This means in your code you should probably not use System.out.println
and
66 | opt instead to use the logger. You can read about the Apache Logger system here .
67 |
68 | Two relevant configurable options are: Execution.Opts.loggerPattern
69 | and Execution.Opts.logLevel
.
70 |
71 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/CallbackFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
/**
 * A callback that accepts an arbitrary number of arguments and returns nothing.
 */
public interface CallbackFn {

    /**
     * Invokes the callback.
     *
     * @param args the arguments handed to the callback (possibly none)
     */
    void callback(Object... args);
}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/Double2DoubleFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
/**
 * A function from a single {@code double} to a {@code double}.
 */
public interface Double2DoubleFn {

    /**
     * Evaluates the function.
     *
     * @param x the point at which to evaluate
     * @return the function value at {@code x}
     */
    double valAt(double x);
}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/DoubleFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
/**
 * A function from an object of type {@code T} to a {@code double}.
 *
 * @param <T> the argument type
 */
public interface DoubleFn<T> {

    /**
     * Evaluates the function.
     *
     * @param x the argument
     * @return the function value at {@code x}
     */
    double valAt(T x);
}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/FactoryFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
/**
 * A zero-argument factory for values of type {@code T}.
 *
 * @param <T> the type of value produced
 */
public interface FactoryFn<T> {

    /**
     * Produces a value.
     *
     * @return the value made by this factory
     */
    T make();

}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/Fn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
3 | import java.io.Serializable;
4 |
/**
 * A serializable one-argument function from {@code I} to {@code O}.
 *
 * @param <I> the input type
 * @param <O> the output type
 */
public interface Fn<I, O> extends Serializable {

    /**
     * Applies the function.
     *
     * @param input the argument
     * @return the result for {@code input}
     */
    O apply(I input);

    /**
     * A function that ignores its input and always returns the value it was built with.
     */
    public static class ConstantFn<I, O> implements Fn<I, O> {

        private final O c;

        /**
         * @param c the value returned by every call to {@link #apply}
         */
        public ConstantFn(O c) {
            this.c = c;
        }

        public O apply(I input) {
            return c;
        }
    }

    /**
     * A function that returns its input unchanged.
     */
    public static class IdentityFn<I> implements Fn<I, I> {

        public I apply(I input) {
            return input;
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/Functional.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
3 | import edu.umass.nlp.ml.sequence.CRF;
4 | import edu.umass.nlp.ml.sequence.ILabeledSeqDatum;
5 | import edu.umass.nlp.utils.*;
6 |
7 | import java.util.*;
8 |
9 |
10 | /**
11 | * Collection of Functional Utilities you'd
12 | * find in any functional programming language.
13 | * Things like map, filter, reduce, etc..
14 | *
15 | */
16 | public class Functional {
17 |
18 |
19 | public static List take(Iterator it, int n) {
20 | List result = new ArrayList();
21 | for (int i=0; i < n && it.hasNext(); ++i) {
22 | result.add(it.next());
23 | }
24 | return result;
25 | }
26 |
27 |
28 | public static IValued findMax(Iterable xs, DoubleFn fn) {
29 | double max = Double.NEGATIVE_INFINITY;
30 | T argMax = null;
31 | for (T x : xs) {
32 | double val = fn.valAt(x);
33 | if (val > max) { max = val ; argMax = x; }
34 | }
35 | return BasicValued.make(argMax,max);
36 | }
37 |
38 | public static IValued findMin(Iterable xs, Fn fn) {
39 | double min= Double.POSITIVE_INFINITY;
40 | T argMin = null;
41 | for (T x : xs) {
42 | double val = fn.apply(x);
43 | if (val < min) { min= val ; argMin = x; }
44 | }
45 | return BasicValued.make(argMin,min);
46 | }
47 |
48 | public static Map map(Map map, Fn fn, PredFn pred, Map resultMap) {
49 | for (Map.Entry entry: map.entrySet()) {
50 | K key = entry.getKey();
51 | I inter = entry.getValue();
52 | if (pred.holdsAt(key)) resultMap.put(key, fn.apply(inter));
53 | }
54 | return resultMap;
55 | }
56 |
57 | public static Map mapPairs(Iterable lst, Fn fn)
58 | {
59 | return mapPairs(lst,fn,new HashMap());
60 | }
61 |
62 | public static Map mapPairs(Iterable lst, Fn fn, Map resultMap)
63 | {
64 | for (I input: lst) {
65 | O output = fn.apply(input);
66 | resultMap.put(input,output);
67 | }
68 | return resultMap;
69 | }
70 |
71 | public static List map(Iterable lst, Fn fn) {
72 | return map(lst,fn,(PredFn) PredFns.getTruePredicate());
73 | }
74 |
75 | public static Iterator map(final Iterator it, final Fn fn) {
76 | return new Iterator() {
77 | public boolean hasNext() {
78 | return it.hasNext();
79 | }
80 |
81 | public O next() {
82 | return fn.apply(it.next());
83 | }
84 |
85 | public void remove() {
86 | throw new RuntimeException("remove() not supported");
87 | }
88 | };
89 | }
90 |
91 | public static Map makeMap(Iterable elems, Fn fn, Map map) {
92 | for (I elem : elems) {
93 | map.put(elem, fn.apply(elem));
94 | }
95 | return map;
96 | }
97 |
98 | public static Map makeMap(Iterable elems, Fn fn) {
99 | return makeMap(elems, fn, new HashMap()) ;
100 | }
101 |
102 | public static List flatMap(Iterable lst,
103 | Fn> fn) {
104 | PredFn> p = PredFns.getTruePredicate();
105 | return flatMap(lst,fn,p);
106 | }
107 |
108 |
109 | public static List flatMap(Iterable lst,
110 | Fn> fn,
111 | PredFn> pred) {
112 | List> lstOfLsts = map(lst,fn,pred);
113 | List init = new ArrayList();
114 | return reduce(lstOfLsts, init,
115 | new Fn, List>, List>() {
116 | public List apply(IPair, List> input) {
117 | List result = input.getFirst();
118 | result.addAll(input.getSecond());
119 | return result;
120 | }
121 | });
122 | }
123 |
124 | public static O reduce(Iterable inputs,
125 | O initial,
126 | Fn,O> fn) {
127 | O output = initial;
128 | for (I input: inputs) {
129 | output = fn.apply(BasicPair.make(output,input));
130 | }
131 | return output;
132 | }
133 |
134 | public static List map(Iterable lst, Fn fn, PredFn pred) {
135 | List outputs = new ArrayList();
136 | for (I input: lst) {
137 | O output = fn.apply(input);
138 | if (pred.holdsAt(output)) {
139 | outputs.add(output);
140 | }
141 | }
142 | return outputs;
143 | }
144 |
145 | public static List filter(final Iterable lst, final PredFn pred) {
146 | List ret = new ArrayList();
147 | for (I input : lst) {
148 | if (pred.holdsAt(input)) ret.add(input);
149 | }
150 | return ret;
151 | }
152 |
153 |
154 | public static T first(Iterable objs, PredFn pred) {
155 | for (T obj : objs) {
156 | if (pred.holdsAt(obj)) return obj;
157 | }
158 | return null;
159 | }
160 |
161 |
162 |
163 | public static List range(int n) {
164 | List result = new ArrayList();
165 | for (int i = 0; i < n; i++) {
166 | result.add(i);
167 | }
168 | return result;
169 | }
170 |
171 | /**
172 | *
173 | * @return
174 | */
175 | public static boolean any(Iterable elems, PredFn p) {
176 | for (T elem : elems) {
177 | if (p.holdsAt(elem)) return true;
178 | }
179 | return false;
180 | }
181 |
182 | public static boolean all(Iterable elems, PredFn p) {
183 | for (T elem : elems) {
184 | if (!p.holdsAt(elem)) return false;
185 | }
186 | return true;
187 | }
188 |
189 |
190 | public static T find(Iterable elems, PredFn pred) {
191 | return first(elems, pred);
192 | }
193 |
194 | public static int findIndex(Iterable elems, PredFn pred) {
195 | int index = 0;
196 | for (T elem : elems) {
197 | if (pred.holdsAt(elem)) return index;
198 | index += 1;
199 | }
200 | return -1;
201 | }
202 |
203 | public static List indicesWhere(Iterable elems, PredFn pred) {
204 | List res = new ArrayList();
205 | int index = 0;
206 | for (T elem : elems) {
207 | if (pred.holdsAt(elem)) {
208 | res.add(index);
209 | }
210 | index ++;
211 | }
212 | return res;
213 | }
214 |
215 | public static String mkString(Iterable elems, String start, String middle, String stop) {
216 | return mkString(elems, start, middle, stop, null);
217 | }
218 |
219 | public static String mkString(Iterable elems, String start, String middle, String stop,Fn strFn) {
220 | StringBuilder sb = new StringBuilder();
221 | sb.append(start);
222 | Iterator it = elems.iterator();
223 | while (it.hasNext()) {
224 | T t = it.next();
225 | sb.append((strFn != null ? strFn.apply(t) : t.toString()));
226 | if (it.hasNext()) {
227 | sb.append(middle);
228 | }
229 | }
230 | sb.append(stop);
231 | return sb.toString();
232 | }
233 |
234 | public static String mkString(Iterable elems) {
235 | return mkString(elems,"(",",",")",null);
236 | }
237 |
238 | public static List takeWhile(Iterable elems, PredFn pred) {
239 | Iterator it = elems.iterator();
240 | return takeWhile(it,pred);
241 | }
242 |
243 | public static List takeWhile(Iterator it, PredFn pred) {
244 | List res = new ArrayList();
245 | while (it.hasNext()) {
246 | T elem = it.next();
247 | if (pred.holdsAt(elem)) res.add(elem);
248 | else break;
249 | }
250 | return res;
251 | }
252 |
253 | public static List
254 | rangesWhere(Iterable elems, PredFn pred) {
255 | int index = 0;
256 | int lastStart = -1;
257 | List res = new ArrayList();
258 | for (T elem: elems) {
259 | boolean matches = pred.holdsAt(elem);
260 | if (matches && lastStart < 0) {
261 | lastStart = index;
262 | }
263 | if (!matches && lastStart >= 0) {
264 | res.add(new Span(lastStart,index));
265 | lastStart = -1;
266 | }
267 | index += 1;
268 | }
269 | if (lastStart >= 0) {
270 | res.add(new Span(lastStart, index));
271 | }
272 | return res;
273 | }
274 |
275 | public static List> subseqsWhere(List elems, PredFn pred) {
276 | List ranges = rangesWhere(elems, pred);
277 | List> res = new ArrayList>();
278 | for (Span span: ranges) {
279 | res.add(new ArrayList(elems.subList(span.getStart(), span.getStop())));
280 | }
281 | return res;
282 | }
283 |
284 | public static FactoryFn curry(final Fn fn, final A fixed) {
285 | return new FactoryFn() {
286 | public R make() {
287 | return fn.apply(fixed);
288 | }
289 | };
290 | }
291 |
292 | public static Iterable lazyMap(final Iterable xs, final Fn fn) {
293 | return new Iterable() {
294 | public Iterator iterator() {
295 | return new Iterator() {
296 |
297 | private Iterator it = xs.iterator();
298 |
299 | public boolean hasNext() {
300 | return it.hasNext();
301 | }
302 |
303 | public O next() {
304 | return fn.apply(it.next());
305 | }
306 |
307 | public void remove() {
308 | throw new RuntimeException("Not Implemented");
309 | }
310 | };
311 | }
312 | };
313 | }
314 | }
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/PredFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
/**
 * A boolean predicate over elements of type {@code T}.
 *
 * @param <T> the element type
 */
public interface PredFn<T> {

    /**
     * Tests an element.
     *
     * @param elem the element to test
     * @return whether the predicate holds for {@code elem}
     */
    boolean holdsAt(T elem);

}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/PredFns.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
3 | public class PredFns {
4 |
5 | public static PredFn getTruePredicate() {
6 | return new PredFn() {
7 | public boolean holdsAt(T elem) {
8 | return true;
9 | }};
10 | }
11 | }
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/functional/Vector2DoubleFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.functional;
2 |
/**
 * A function from a {@code double[]} vector to a single {@code double}.
 */
public interface Vector2DoubleFn {

    /**
     * Evaluates the function.
     *
     * @param x the vector at which to evaluate
     * @return the function value at {@code x}
     */
    double valAt(double[] x);
}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/io/IOUtils.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.io;
2 |
3 |
4 | import edu.umass.nlp.functional.Functional;
5 |
6 | import java.io.*;
7 | import java.util.ArrayList;
8 | import java.util.Collections;
9 | import java.util.Iterator;
10 | import java.util.List;
11 | import java.util.zip.GZIPInputStream;
12 | import java.util.zip.ZipInputStream;
13 |
/**
 * Static helpers for reading and writing lines, text and serialized objects.
 *
 * Most methods are best-effort: on failure they print the stack trace and
 * return {@code null}, an empty list, or whatever was read so far, as noted
 * on each method. Readers, writers and streams opened by a method are closed
 * by it; ones passed in by the caller are left open. Character data is
 * decoded and encoded with the platform default charset.
 */
public class IOUtils {

    /**
     * Lazily iterates over the lines of a stream.
     *
     * @throws IllegalStateException if a reader cannot be created over {@code is}
     */
    public static Iterable<String> lazyLines(final InputStream is) {
        try {
            return lazyLines(new InputStreamReader(is));
        } catch (Exception e) {
            throw new IllegalStateException("Could not create a reader over the input stream", e);
        }
    }

    /**
     * Lazily iterates over the lines of a file.
     *
     * @throws IllegalStateException if the file cannot be opened
     */
    public static Iterable<String> lazyLines(final File path) {
        try {
            return lazyLines(new FileReader(path));
        } catch (Exception e) {
            throw new IllegalStateException("Could not open " + path + " for reading", e);
        }
    }

    /**
     * Lazily iterates over the lines of the file at {@code path}.
     *
     * @throws IllegalStateException if the file cannot be opened
     */
    public static Iterable<String> lazyLines(final String path) {
        try {
            return lazyLines(new FileReader(path));
        } catch (Exception e) {
            throw new IllegalStateException("Could not open " + path + " for reading", e);
        }
    }

    /**
     * Lazily iterates over the lines of {@code reader}.
     *
     * All iterators obtained from the result draw from the same underlying
     * reader, so the result can effectively be traversed once. Calling
     * {@code next()} after the last line returns {@code null}. A read failure
     * surfaces as an {@link IllegalStateException} instead of terminating the JVM.
     */
    public static Iterable<String> lazyLines(final Reader reader) {
        return new Iterable<String>() {
            public Iterator<String> iterator() {
                final BufferedReader buffered = new BufferedReader(reader);
                return new Iterator<String>() {
                    private String nextLine;
                    // true when nextLine has been handed out and a fresh line must be read
                    private boolean consumed = true;

                    private void queue() {
                        if (!consumed) {
                            return;
                        }
                        try {
                            nextLine = buffered.readLine();
                            consumed = false;
                        } catch (IOException e) {
                            throw new IllegalStateException("Failed to read the next line", e);
                        }
                    }

                    public boolean hasNext() {
                        queue();
                        return nextLine != null;
                    }

                    public String next() {
                        queue();
                        String ret = nextLine;
                        consumed = true;
                        return ret;
                    }

                    public void remove() {
                        throw new UnsupportedOperationException("Not Implemented");
                    }
                };
            }
        };
    }

    /** Reads all lines of the stream; the stream is not closed. */
    public static List<String> lines(InputStream is) {
        return lines(new InputStreamReader(is));
    }

    /** Reads all lines of the file at {@code f}; returns an empty list if it cannot be opened. */
    public static List<String> lines(String f) {
        return lines(new File(f));
    }

    /** Reads all lines of the file; returns an empty list if it cannot be opened. */
    public static List<String> lines(File f) {
        Reader r = null;
        try {
            r = new FileReader(f);
            return lines(r);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            closeQuietly(r);
        }
        return Collections.emptyList();
    }

    /**
     * Reads all remaining lines of {@code r}, which is left open. On a read
     * failure the stack trace is printed and the lines read so far are returned.
     */
    public static List<String> lines(Reader r) {
        List<String> res = new ArrayList<String>();
        try {
            BufferedReader br = new BufferedReader(r);
            String line;
            while ((line = br.readLine()) != null) {
                res.add(line);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return res;
    }

    /** Opens a reader on the file, or returns {@code null} if it cannot be opened. */
    public static Reader reader(File f) {
        try {
            return new FileReader(f);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Opens the named file, wrapping it in a {@link GZIPInputStream} or
     * {@link ZipInputStream} when the name ends in {@code .gz} or {@code .zip}.
     *
     * @return the stream, or {@code null} if it cannot be opened
     */
    public static InputStream inputStream(String name) {
        InputStream is = null;
        try {
            is = new FileInputStream(name);
            if (name.endsWith(".gz")) return new GZIPInputStream(is);
            if (name.endsWith(".zip")) return new ZipInputStream(is);
            return is;
        } catch (Exception e) {
            e.printStackTrace();
            // Do not leak the file handle when the wrapper cannot be constructed.
            closeQuietly(is);
        }
        return null;
    }

    /**
     * Opens a reader on a system class-loader resource.
     *
     * @throws IllegalArgumentException if the resource does not exist
     */
    public static Reader readerFromResource(String resourcePath) {
        InputStream is = ClassLoader.getSystemResourceAsStream(resourcePath);
        if (is == null) {
            throw new IllegalArgumentException("Resource not found: " + resourcePath);
        }
        return new InputStreamReader(is);
    }

    /** Reads all lines of a system class-loader resource. */
    public static List<String> linesFromResource(String resourcePath) {
        Reader r = readerFromResource(resourcePath);
        try {
            return lines(r);
        } finally {
            closeQuietly(r);
        }
    }

    public static boolean exists(String f) {
        return (new File(f)).exists();
    }

    public static boolean exists(File f) {
        return f.exists();
    }

    /**
     * Replaces the extension of the last path component with {@code newExt}
     * (a leading dot is added when missing). A path whose last component has
     * no extension is returned unchanged.
     */
    public static String changeExt(String path, String newExt) {
        if (!newExt.startsWith(".")) {
            newExt = "." + newExt;
        }
        // Excluding separators keeps the match inside the last path component;
        // quoteReplacement keeps '$' and '\' in the extension literal.
        return path.replaceAll("\\.[^./\\\\]+$", java.util.regex.Matcher.quoteReplacement(newExt));
    }

    /** Returns the path of a file with the same name as {@code path} inside {@code newDir}. */
    public static String changeDir(String path, String newDir) {
        File f = new File(path);
        return (new File(newDir, f.getName())).getPath();
    }

    /** Returns the stream's lines joined by {@code "\n"}; the stream is not closed. */
    public static String text(InputStream is) {
        return joinLines(lines(is));
    }

    /** Returns the lines of the file at {@code path} joined by {@code "\n"}. */
    public static String text(String path) {
        return text(new File(path));
    }

    /** Returns the reader's remaining lines joined by {@code "\n"}; the reader is not closed. */
    public static String text(Reader r) {
        return joinLines(lines(r));
    }

    /** Returns the file's lines joined by {@code "\n"}; empty if the file cannot be opened. */
    public static String text(File f) {
        return joinLines(lines(f));
    }

    /** Writes each line followed by the platform line separator; failures are printed. */
    public static void writeLines(File f, List<String> lines) {
        PrintWriter writer = null;
        try {
            writer = new PrintWriter(new FileWriter(f));
            for (String line : lines) {
                writer.println(line);
            }
            writer.flush();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    public static void writeLines(String f, List<String> lines) {
        writeLines(new File(f), lines);
    }

    /**
     * Reads serialized objects until the end of the stream or a {@code null}
     * object is reached. The stream is left open. Any failure other than
     * reaching the end of the stream is printed, and the objects read so far
     * are returned.
     */
    public static List<Object> readObjects(InputStream is) {
        List<Object> ret = new ArrayList<Object>();
        try {
            ObjectInputStream ois = new ObjectInputStream(is);
            while (true) {
                Object o = ois.readObject();
                if (o == null) break;
                ret.add(o);
            }
        } catch (EOFException endOfStream) {
            // Normal termination: readObject signals the end of the data this way.
        } catch (Exception e) {
            e.printStackTrace();
        }
        return ret;
    }

    /** Reads one serialized object from the file, or returns {@code null} on failure. */
    public static Object readObject(String path) {
        InputStream is = inputStream(path);
        try {
            return readObject(is);
        } finally {
            closeQuietly(is);
        }
    }

    /** Reads one serialized object from the stream (left open), or returns {@code null} on failure. */
    public static Object readObject(InputStream is) {
        try {
            ObjectInputStream ois = new ObjectInputStream(is);
            return ois.readObject();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /** Reads all serialized objects from the file; empty if it cannot be opened. */
    public static List<Object> readObjects(String path) {
        InputStream is = inputStream(path);
        if (is == null) {
            // inputStream has already reported why the file could not be opened.
            return new ArrayList<Object>();
        }
        try {
            return readObjects(is);
        } finally {
            closeQuietly(is);
        }
    }

    /** Serializes {@code o} to the file at {@code path}; failures are printed. */
    public static void writeObject(Object o, String path) {
        OutputStream os = null;
        try {
            os = new FileOutputStream(new File(path));
            ObjectOutputStream oos = new ObjectOutputStream(os);
            oos.writeObject(o);
            oos.close();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Covers the failure paths; closing an already-closed stream is a no-op.
            closeQuietly(os);
        }
    }

    /** Opens a writer on the file, or returns {@code null} if it cannot be opened. */
    public static PrintWriter getPrintWriter(String path) {
        try {
            return new PrintWriter(new FileWriter(path));
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Copies a text file line by line. Line endings are rewritten to the
     * platform separator, so this is not a byte-for-byte copy.
     */
    public static void copy(String src, String dest) {
        List<String> lines = IOUtils.lines(src);
        IOUtils.writeLines(dest, lines);
    }

    /** Joins lines with {@code "\n"}, with no leading or trailing separator. */
    private static String joinLines(List<String> lines) {
        StringBuilder sb = new StringBuilder();
        Iterator<String> it = lines.iterator();
        while (it.hasNext()) {
            sb.append(it.next());
            if (it.hasNext()) {
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    /** Closes {@code c} if non-null, ignoring any failure to close. */
    private static void closeQuietly(Closeable c) {
        if (c == null) {
            return;
        }
        try {
            c.close();
        } catch (IOException ignored) {
            // Nothing useful can be done about a failed close here.
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/io/ZipUtils.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.io;
2 |
3 |
4 | import edu.umass.nlp.functional.CallbackFn;
5 | import edu.umass.nlp.functional.Fn;
6 |
7 | import java.io.IOException;
8 | import java.io.InputStream;
9 | import java.util.List;
10 | import java.util.zip.ZipEntry;
11 | import java.util.zip.ZipFile;
12 | import java.util.zip.ZipOutputStream;
13 |
14 | public class ZipUtils {
15 |
16 | public static ZipFile getZipFile(String name) {
17 | try {
18 | return new ZipFile(name);
19 | } catch (Exception e) {
20 | e.printStackTrace();
21 | }
22 | return null;
23 | }
24 |
25 | public static InputStream getEntryInputStream(ZipFile zf, String entryName) {
26 | try {
27 | return zf.getInputStream(zf.getEntry(entryName));
28 | } catch (Exception e) {
29 | e.printStackTrace();
30 | }
31 | return null;
32 | }
33 |
34 | public static List getEntryLines(ZipFile zf, String entryName) {
35 | try {
36 | return IOUtils.lines(getEntryInputStream(zf, entryName));
37 | } catch (Exception e) { e.printStackTrace(); }
38 | return null;
39 | }
40 |
41 | public static boolean entryExists(ZipFile zipFile, String entryName) {
42 | return zipFile.getEntry(entryName) != null;
43 | }
44 |
45 | public static void main(String[] args) {
46 | ZipFile root = ZipUtils.getZipFile(args[0]);
47 |
48 | }
49 |
50 | public static void doZipEntry(ZipOutputStream zos, String entryName, CallbackFn entryFn) {
51 | try {
52 | ZipEntry ze = new ZipEntry(entryName);
53 | zos.putNextEntry(ze);
54 | entryFn.callback();
55 | zos.closeEntry();
56 | } catch (IOException e) {
57 | e.printStackTrace();
58 | }
59 | }
60 |
61 | public static void print(ZipOutputStream zos, String text) {
62 | try {
63 | zos.write(text.getBytes());
64 | } catch (IOException e) {
65 | e.printStackTrace();
66 | }
67 | }
68 |
69 | public static void println(ZipOutputStream zos, String text) { print(zos, text + "\n"); }
70 | }
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/ml/F1Stats.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.ml;
2 |
3 | import edu.umass.nlp.utils.IMergable;
4 |
5 | public class F1Stats implements IMergable {
6 |
7 | public int tp = 0, fp = 0, fn = 0;
8 | public final String label;
9 |
10 | public F1Stats(String label) {
11 | this.label = label;
12 | }
13 |
14 | public void merge(F1Stats other) {
15 | tp += other.tp;
16 | fp += other.fp;
17 | fn += other.fn;
18 | }
19 |
20 | public double getPrecision() {
21 | if (tp + fp > 0.0) {
22 | return (tp / (tp + fp + 0.0));
23 | } else {
24 | return 0.0;
25 | }
26 | }
27 |
28 | public double getRecall() {
29 | if (tp + fn > 0.0) {
30 | return (tp / (tp + fn + 0.0));
31 | } else {
32 | return 0.0;
33 | }
34 | }
35 |
36 | public double getFMeasure(double beta) {
37 | double p = getPrecision();
38 | double r = getRecall();
39 | if (p + r > 0.0) {
40 | return ((1+beta*beta)* p * r) / ((beta*beta)*p + r);
41 | } else {
42 | return 0.0;
43 | }
44 | }
45 |
46 | public void observe(String trueLabel, String guessLabel) {
47 | assert (label.equals(trueLabel) || label.equals(guessLabel));
48 | if (label.equals(trueLabel)) {
49 | if (trueLabel.equals(guessLabel)) {
50 | tp++;
51 | } else {
52 | fn++;
53 | }
54 | } else {
55 | fp++;
56 | }
57 | if (trueLabel.equals(label)) {
58 | tp++;
59 | } else if (label.equals(trueLabel)) {
60 | fn++;
61 | } else if (label.equals(guessLabel)) {
62 | fp++;
63 | }
64 | }
65 |
66 | public String toString() {
67 | return String.format("f1: %.3f f2: %.3f prec: %.3f recall: %.3f (tp: %d, fp: %d, fn: %d)",
68 | getFMeasure(1.0), getFMeasure(2.0), getPrecision(), getRecall(), tp, fp, fn);
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/ml/LossFn.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.ml;
2 |
/**
 * A loss between a true label and a guessed label of type {@code L}.
 *
 * @param <L> the label type
 */
public interface LossFn<L> {

    /**
     * @param trueLabel  the correct label
     * @param guessLabel the predicted label
     * @return the loss incurred by guessing {@code guessLabel} when the truth is {@code trueLabel}
     */
    double getLoss(L trueLabel, L guessLabel);
}
6 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/ml/LossFns.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.ml;
2 |
3 | import edu.umass.nlp.functional.Fn;
4 | import edu.umass.nlp.utils.Collections;
5 | import edu.umass.nlp.utils.ICounter;
6 | import edu.umass.nlp.utils.IPair;
7 | import edu.umass.nlp.utils.MapCounter;
8 |
9 | import java.util.Collection;
10 | import java.util.HashMap;
11 | import java.util.Map;
12 |
13 |
14 | public class LossFns {
15 |
16 | public static Map> compileLossFn(LossFn lossFn, Collection labels) {
17 | Map> res = new HashMap>();
18 | for (L label : labels) {
19 | for (L otherLabel : labels) {
20 | Collections.getMut(res, label, new MapCounter())
21 | .incCount(otherLabel, lossFn.getLoss(label, otherLabel));
22 | }
23 | }
24 | return res;
25 | }
26 |
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/ml/Regularizers.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.ml;
2 |
3 | import edu.umass.nlp.functional.Fn;
4 | import edu.umass.nlp.utils.BasicPair;
5 | import edu.umass.nlp.utils.IPair;
6 |
7 | public class Regularizers {
8 |
9 | public static Fn> getL2Regularizer(final double sigmaSq) {
10 | return new Fn>() {
11 | public IPair apply(double[] input) {
12 | double obj = 0.0;
13 | double[] grad = new double[input.length];
14 | for (int i = 0; i < input.length; ++i) {
15 | double w = input[i];
16 | obj += w * w / sigmaSq;
17 | grad[i] += 2 * w / sigmaSq;
18 | }
19 | return new BasicPair(obj, grad);
20 | }
21 | };
22 | }
23 |
24 | }
25 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/ml/classification/BasicClassifierDatum.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.ml.classification;
2 |
3 | import edu.umass.nlp.utils.IValued;
4 |
5 | import java.util.List;
6 |
7 | public class BasicClassifierDatum implements ClassifierDatum {
8 | private final List> preds;
9 |
10 | public BasicClassifierDatum(List> preds) {
11 | this.preds = preds;
12 | }
13 |
14 | @Override
15 | public List> getPredicates() {
16 | return preds;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/src/main/java/edu/umass/nlp/ml/classification/BasicLabeledClassifierDatum.java:
--------------------------------------------------------------------------------
1 | package edu.umass.nlp.ml.classification;
2 |
3 | import edu.umass.nlp.utils.BasicValued;
4 | import edu.umass.nlp.utils.IValued;
5 |
6 | import java.util.ArrayList;
7 | import java.util.List;
8 |
9 | public class BasicLabeledClassifierDatum implements LabeledClassifierDatum {
10 | private final List> preds;
11 | private final L label;
12 |
13 | public BasicLabeledClassifierDatum(List> preds, L label) {
14 | this.preds = preds;
15 | this.label = label;
16 | }
17 |
18 | @Override
19 | public L getTrueLabel() {
20 | return label;
21 | }
22 |
23 | @Override
24 | public List> getPredicates() {
25 | return preds;
26 | }
27 |
28 | /**
29 | *
30 | */
31 | public static LabeledClassifierDatum
32 | getBinaryDatum(L label,String... preds) {
33 | List> predPairs = new ArrayList