")
25 | os.Exit(1)
26 | }
27 | startIdx, err := strconv.Atoi(os.Args[1])
28 | if err != nil {
29 | fmt.Fprintln(os.Stderr, err)
30 | os.Exit(1)
31 | }
32 | endIdx, err := strconv.Atoi(os.Args[2])
33 | if err != nil {
34 | fmt.Fprintln(os.Stderr, err)
35 | os.Exit(1)
36 | }
37 | pasteIndices := make(chan int)
38 | go func() {
39 | for i := startIdx; i <= endIdx; i++ {
40 | pasteIndices <- i
41 | }
42 | close(pasteIndices)
43 | }()
44 | outDir := os.Args[3]
45 | if err := ensureDirectoryPresent(outDir); err != nil {
46 | fmt.Fprintln(os.Stderr, err)
47 | os.Exit(1)
48 | }
49 | fetchPastes(pasteIndices, outDir)
50 | }
51 |
52 | func fetchPastes(indices <-chan int, outDir string) {
53 | var wg sync.WaitGroup
54 | for i := 0; i < RoutineCount; i++ {
55 | wg.Add(1)
56 | go func() {
57 | defer wg.Done()
58 | for index := range indices {
59 | if err := fetchPaste(index, outDir); err != nil {
60 | fmt.Fprintln(os.Stderr, "error for paste", index, err)
61 | } else {
62 | fmt.Println("succeeded for paste", index)
63 | }
64 | }
65 | }()
66 | }
67 | wg.Wait()
68 | }
69 |
70 | func fetchPaste(index int, outDir string) error {
71 | code, lang, err := fetchPasteCode(index)
72 | if err != nil {
73 | return err
74 | }
75 | codeDir := filepath.Join(outDir, lang)
76 | ensureDirectoryPresent(codeDir)
77 | fileName := strconv.Itoa(index) + ".txt"
78 | return ioutil.WriteFile(filepath.Join(codeDir, fileName), []byte(code), 0755)
79 | }
80 |
81 | func fetchPasteCode(index int) (contents, language string, err error) {
82 | response, err := http.Get("http://pastie.org/pastes/" + strconv.Itoa(index))
83 | if err != nil {
84 | return
85 | }
86 | body, err := ioutil.ReadAll(response.Body)
87 | response.Body.Close()
88 | if err != nil {
89 | return
90 | }
91 | pageData := string(body)
92 | exp := regexp.MustCompile("\n(.*?)\n ")
93 | match := exp.FindStringSubmatch(pageData)
94 | if match == nil {
95 | ioutil.WriteFile("/Users/alex/Desktop/foo.html", []byte(pageData), 0755)
96 | return "", "", errors.New("cannot locate language")
97 | }
98 | language = match[1]
99 | language = strings.Replace(language, "/", ":", -1)
100 |
101 | response, err = http.Get("http://pastie.org/pastes/" + strconv.Itoa(index) + "/text")
102 | if err != nil {
103 | return
104 | }
105 | root, err := html.Parse(response.Body)
106 | response.Body.Close()
107 | if err != nil {
108 | return
109 | }
110 | codeBlock, ok := scrape.Find(root, scrape.ByTag(atom.Pre))
111 | if !ok {
112 | return "", "", errors.New("no
tag")
113 | }
114 | contents = codeBlockText(codeBlock)
115 | return
116 | }
117 |
118 | func codeBlockText(n *html.Node) string {
119 | if n.DataAtom == atom.Br {
120 | return "\n"
121 | }
122 | if n.Type == html.TextNode {
123 | return n.Data
124 | }
125 |
126 | var res string
127 | child := n.FirstChild
128 | for child != nil {
129 | res += codeBlockText(child)
130 | child = child.NextSibling
131 | }
132 | return res
133 | }
134 |
135 | func ensureDirectoryPresent(dirPath string) error {
136 | if _, err := os.Stat(dirPath); err != nil {
137 | if err := os.Mkdir(dirPath, 0755); err != nil {
138 | return err
139 | }
140 | }
141 | return nil
142 | }
143 |
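A note on the shape of the code above: `main` is the producer, feeding paste indices into a channel and closing it, while `fetchPastes` drains the channel with a fixed pool of `RoutineCount` workers under a `sync.WaitGroup`. A self-contained sketch of just that producer/worker idiom (the integer jobs are hypothetical):

```go
package main

import (
	"fmt"
	"sync"
)

const RoutineCount = 4

func main() {
	jobs := make(chan int)

	// Producer: feed the channel, then close it so the workers'
	// range loops terminate.
	go func() {
		for i := 1; i <= 10; i++ {
			jobs <- i
		}
		close(jobs)
	}()

	// Fixed-size worker pool draining the channel concurrently.
	var wg sync.WaitGroup
	for i := 0; i < RoutineCount; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for job := range jobs {
				fmt.Println("handled job", job)
			}
		}()
	}
	wg.Wait()
}
```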
--------------------------------------------------------------------------------
/neuralnet/classifier.go:
--------------------------------------------------------------------------------
1 | package neuralnet
2 |
3 | import (
4 | "encoding/json"
5 | "math"
6 |
7 | "github.com/unixpickle/num-analysis/kahan"
8 | "github.com/unixpickle/whichlang/tokens"
9 | )
10 |
11 | // A Network is a feedforward neural network with
12 | // a single hidden layer.
13 | type Network struct {
14 | Tokens []string
15 | Langs []string
16 |
17 | // In the following weights, the last weight for
18 | // each neuron corresponds to a constant shift,
19 | // and is not multiplied by an input's value.
20 | HiddenWeights [][]float64
21 | OutputWeights [][]float64
22 |
23 | // Information used to center the input
24 | // frequencies around 0 and scale them to
25 | // a standard deviation of 1.
26 | InputShift float64
27 | InputScale float64
28 | }
29 |
30 | func DecodeNetwork(data []byte) (*Network, error) {
31 | var n Network
32 | if err := json.Unmarshal(data, &n); err != nil {
33 | return nil, err
34 | }
35 | return &n, nil
36 | }
37 |
38 | func (n *Network) Copy() *Network {
39 | res := &Network{
40 | Tokens: make([]string, len(n.Tokens)),
41 | Langs: make([]string, len(n.Langs)),
42 | HiddenWeights: make([][]float64, len(n.HiddenWeights)),
43 | OutputWeights: make([][]float64, len(n.OutputWeights)),
44 | InputShift: n.InputShift,
45 | InputScale: n.InputScale,
46 | }
47 | copy(res.Tokens, n.Tokens)
48 | copy(res.Langs, n.Langs)
49 | for i, w := range n.HiddenWeights {
50 | res.HiddenWeights[i] = make([]float64, len(w))
51 | copy(res.HiddenWeights[i], w)
52 | }
53 | for i, w := range n.OutputWeights {
54 | res.OutputWeights[i] = make([]float64, len(w))
55 | copy(res.OutputWeights[i], w)
56 | }
57 | return res
58 | }
59 |
60 | func (n *Network) Classify(freqs tokens.Freqs) string {
61 | inputs := n.shiftedInput(freqs)
62 |
63 | outputSums := make([]*kahan.Summer64, len(n.OutputWeights))
64 | for i := range outputSums {
65 | outputSums[i] = kahan.NewSummer64()
66 | outputSums[i].Add(n.outputBias(i))
67 | }
68 |
69 | for hiddenIndex, hiddenWeights := range n.HiddenWeights {
70 | hiddenSum := kahan.NewSummer64()
71 | for j, input := range inputs {
72 | hiddenSum.Add(input * hiddenWeights[j])
73 | }
74 | hiddenSum.Add(n.hiddenBias(hiddenIndex))
75 |
76 | hiddenOut := sigmoid(hiddenSum.Sum())
77 | for j, outSum := range outputSums {
78 | weight := n.OutputWeights[j][hiddenIndex]
79 | outSum.Add(weight * hiddenOut)
80 | }
81 | }
82 |
83 | maxSum := outputSums[0].Sum()
84 | maxIdx := 0
85 | for i, x := range outputSums {
86 | if x.Sum() > maxSum {
87 | maxSum = x.Sum()
88 | maxIdx = i
89 | }
90 | }
91 | return n.Langs[maxIdx]
92 | }
93 |
94 | func (n *Network) Encode() []byte {
95 | enc, _ := json.Marshal(n)
96 | return enc
97 | }
98 |
99 | func (n *Network) Languages() []string {
100 | return n.Langs
101 | }
102 |
103 | func (n *Network) outputBias(outputIdx int) float64 {
104 | return n.OutputWeights[outputIdx][len(n.HiddenWeights)]
105 | }
106 |
107 | func (n *Network) hiddenBias(hiddenIdx int) float64 {
108 | return n.HiddenWeights[hiddenIdx][len(n.Tokens)]
109 | }
110 |
111 | func (n *Network) containsNaN() bool {
112 | for _, wss := range [][][]float64{n.HiddenWeights, n.OutputWeights} {
113 | for _, ws := range wss {
114 | for _, w := range ws {
115 | if math.IsNaN(w) {
116 | return true
117 | }
118 | }
119 | }
120 | }
121 | return false
122 | }
123 |
124 | func (n *Network) shiftedInput(f tokens.Freqs) []float64 {
125 | res := make([]float64, len(n.Tokens))
126 | for i, word := range n.Tokens {
127 | res[i] = (f[word] + n.InputShift) * n.InputScale
128 | }
129 | return res
130 | }
131 |
132 | func sigmoid(x float64) float64 {
133 | return 1.0 / (1.0 + math.Exp(-x))
134 | }
135 |
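The comment on the weight matrices above notes that each neuron's final weight is a constant bias that is never multiplied by an input; `hiddenBias` and `outputBias` read that trailing slot. A minimal sketch of a single hidden neuron under this convention (the weight values are made up):

```go
package main

import (
	"fmt"
	"math"
)

// neuronOutput mirrors the bias-as-last-weight convention used by
// Network: weights holds len(inputs) multipliers followed by one bias
// term, and the weighted sum is squashed by the logistic sigmoid.
func neuronOutput(weights, inputs []float64) float64 {
	sum := weights[len(inputs)] // trailing bias, tied to no input
	for i, x := range inputs {
		sum += weights[i] * x
	}
	return 1.0 / (1.0 + math.Exp(-sum))
}

func main() {
	// Two inputs, so three weights: w1, w2, bias.
	fmt.Println(neuronOutput([]float64{0.5, -0.25, 0.1}, []float64{1, 2}))
}
```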
--------------------------------------------------------------------------------
/tokens/sample_counts.go:
--------------------------------------------------------------------------------
1 | package tokens
2 |
3 | import (
4 | "io/ioutil"
5 | "os"
6 | "path/filepath"
7 | "sort"
8 | "strings"
9 | )
10 |
11 | // SampleCounts maps programming languages to
12 | // arrays of language sample documents, where
13 | // each sample is represented by a Counts.
14 | type SampleCounts map[string][]Counts
15 |
16 | // ReadSampleCounts computes token counts
17 | // for programming language samples in a
18 | // directory.
19 | //
20 | // The directory should contain sub-directories
21 | // for each programming language, and each of
22 | // these languages should contain one or more
23 | // source files.
24 | //
25 | // The returned map maps language names to lists
26 | // of Counts, where each Counts corresponds to
27 | // one source file.
28 | func ReadSampleCounts(sampleDir string) (SampleCounts, error) {
29 | languages, err := readDirectory(sampleDir, true)
30 | if err != nil {
31 | return nil, err
32 | }
33 |
34 | res := SampleCounts{}
35 | for _, language := range languages {
36 | langDir := filepath.Join(sampleDir, language)
37 | files, err := readDirectory(langDir, false)
38 | if err != nil {
39 | return nil, err
40 | }
41 | for _, file := range files {
42 | contents, err := ioutil.ReadFile(filepath.Join(langDir, file))
43 | if err != nil {
44 | return nil, err
45 | }
46 | counts := CountTokens(string(contents))
47 | res[language] = append(res[language], counts)
48 | }
49 | }
50 |
51 | return res, nil
52 | }
53 |
54 | // NumTokens returns the number of unique
55 | // tokens in all the documents.
56 | func (s SampleCounts) NumTokens() int {
57 | toks := map[string]bool{}
58 | for _, samples := range s {
59 | for _, sample := range samples {
60 | for word := range sample {
61 | toks[word] = true
62 | }
63 | }
64 | }
65 | return len(toks)
66 | }
67 |
68 | // Prune removes tokens which appear in n
69 | // documents or fewer.
70 | //
71 | // This creates a "" token in each document
72 | // corresponding to the number of pruned
73 | // tokens from that document.
74 | func (s SampleCounts) Prune(n int) {
75 | docCount := map[string]int{}
76 | for _, samples := range s {
77 | for _, sample := range samples {
78 | for word := range sample {
79 | docCount[word]++
80 | }
81 | }
82 | }
83 |
84 | remove := map[string]bool{}
85 | for word, count := range docCount {
86 | if count <= n {
87 | remove[word] = true
88 | }
89 | }
90 |
91 | for _, samples := range s {
92 | for i, sample := range samples {
93 | newSample := map[string]int{}
94 | removed := 0
95 | for word, count := range sample {
96 | if !remove[word] {
97 | newSample[word] = count
98 | } else {
99 | removed += count
100 | }
101 | }
102 | if removed > 0 {
103 | newSample[""] += removed
104 | }
105 | samples[i] = newSample
106 | }
107 | }
108 | }
109 |
110 | // SampleFreqs converts every Counts object
111 | // in s into a Freqs object.
112 | // The "" key in each Freqs object is deleted
113 | // if one exists.
114 | func (s SampleCounts) SampleFreqs() map[string][]Freqs {
115 | res := map[string][]Freqs{}
116 | for lang, samples := range s {
117 | for _, sample := range samples {
118 | f := sample.Freqs()
119 | delete(f, "")
120 | res[lang] = append(res[lang], f)
121 | }
122 | }
123 | return res
124 | }
125 |
126 | func readDirectory(dir string, isDir bool) ([]string, error) {
127 | f, err := os.Open(dir)
128 | if err != nil {
129 | return nil, err
130 | }
131 | defer f.Close()
132 | contents, err := f.Readdir(-1)
133 | if err != nil {
134 | return nil, err
135 | }
136 | res := make([]string, 0, len(contents))
137 | for _, info := range contents {
138 | if info.IsDir() == isDir && !strings.HasPrefix(info.Name(), ".") {
139 | res = append(res, info.Name())
140 | }
141 | }
142 | sort.Strings(res)
143 | return res, nil
144 | }
145 |
--------------------------------------------------------------------------------
/cmd/fetchlang/file_search.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "errors"
7 | "net/url"
8 | "path"
9 | "strings"
10 | )
11 |
12 | var (
13 | ErrNoResults = errors.New("no results")
14 | ErrMaxRequests = errors.New("too many requests")
15 | )
16 |
17 | // A FileSearch defines parameters for
18 | // searching a repository for files.
19 | type FileSearch struct {
20 | // Repository is the repository name,
21 | // formatted as "user/repo".
22 | Repository string
23 |
24 | MinFileSize int
25 | MaxFileSize int
26 | Extensions []string
27 |
28 | // MaxRequests is the maximum number of
29 | // API requests to be performed by the
30 | // search before giving up.
31 | MaxRequests int
32 | }
33 |
34 | // SearchFile runs a FileSearch.
35 | // It returns ErrMaxRequests if more than
36 | // s.MaxRequests requests are used.
37 | // It returns ErrNoResults if no results
38 | // are found.
39 | func (g *GithubClient) SearchFile(s FileSearch) (contents []byte, err error) {
40 | return g.firstFileSearch(&s, "/")
41 | }
42 |
43 | func (g *GithubClient) firstFileSearch(s *FileSearch, dir string) (match []byte, err error) {
44 | if s.MaxRequests == 0 {
45 | return nil, ErrMaxRequests
46 | }
47 |
48 | u := url.URL{
49 | Scheme: "https",
50 | Host: "api.github.com",
51 | Path: path.Join("/repos", s.Repository, "/contents", dir),
52 | }
53 |
54 | body, _, err := g.request(u.String())
55 | if err != nil {
56 | return nil, err
57 | }
58 |
59 | s.MaxRequests--
60 |
61 | var result []entity
62 | if err := json.Unmarshal(body, &result); err != nil {
63 | return nil, err
64 | }
65 |
66 | for _, ent := range result {
67 | if ent.Match(s) {
68 | return g.readFile(s.Repository, ent.Path)
69 | }
70 | }
71 |
72 | sourceDirectoryHeuristic(result, s.Repository)
73 |
74 | for _, ent := range result {
75 | if ent.Dir() {
76 | match, err = g.firstFileSearch(s, ent.Path)
77 | if match != nil || (err != nil && err != ErrNoResults) {
78 | return
79 | }
80 | }
81 | }
82 |
83 | return nil, ErrNoResults
84 | }
85 |
86 | func (g *GithubClient) readFile(repo, filePath string) ([]byte, error) {
87 | u := url.URL{
88 | Scheme: "https",
89 | Host: "api.github.com",
90 | Path: path.Join("/repos", repo, "/contents", filePath),
91 | }
92 | body, _, err := g.request(u.String())
93 | if err != nil {
94 | return nil, err
95 | }
96 |
97 | var result struct {
98 | Content string `json:"content"`
99 | Encoding string `json:"encoding"`
100 | }
101 | if err := json.Unmarshal(body, &result); err != nil {
102 | return nil, err
103 | }
104 |
105 | if result.Encoding == "base64" {
106 | return base64.StdEncoding.DecodeString(result.Content)
107 | } else {
108 | return nil, errors.New("unknown encoding: " + result.Encoding)
109 | }
110 | }
111 |
112 | type entity struct {
113 | Name string `json:"name"`
114 | Path string `json:"path"`
115 | Size int `json:"size"`
116 | Type string `json:"type"`
117 | }
118 |
119 | func (e *entity) Dir() bool {
120 | return e.Type == "dir"
121 | }
122 |
123 | func (e *entity) Match(s *FileSearch) bool {
124 | if e.Type != "file" {
125 | return false
126 | }
127 | if e.Size < s.MinFileSize || e.Size > s.MaxFileSize {
128 | return false
129 | }
130 | for _, ext := range s.Extensions {
131 | if strings.HasSuffix(e.Name, "."+ext) {
132 | return true
133 | }
134 | }
135 | return false
136 | }
137 |
138 | // sourceDirectoryHeuristic sorts a list of
139 | // entities so that the first ones are more
140 | // likely to contain source code.
141 | func sourceDirectoryHeuristic(results []entity, repoName string) {
142 | sourceDirs := []string{"src", repoName, "lib", "com", "org", "net", "css", "assets"}
143 | numFound := 0
144 | for _, sourceDir := range sourceDirs {
145 | for i, ent := range results[numFound:] {
146 | if ent.Dir() && ent.Name == sourceDir {
147 | results[numFound], results[i] = results[i], results[numFound]
148 | numFound++
149 | break
150 | }
151 | }
152 | }
153 | }
154 |
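A sketch of how a caller in this package might drive the search, per the `FileSearch` field docs above. The helper name, size limits, and request budget are placeholder choices, and constructing the `GithubClient` (defined elsewhere in this package) is omitted:

```go
// firstGoFile is a hypothetical helper that finds one moderately-sized
// Go source file in a repository, or fails with ErrNoResults or
// ErrMaxRequests.
func firstGoFile(client *GithubClient, repo string) ([]byte, error) {
	search := FileSearch{
		Repository:  repo,           // "user/repo" format
		MinFileSize: 100,            // skip near-empty files
		MaxFileSize: 1 << 16,        // skip huge generated files
		Extensions:  []string{"go"}, // matched as a ".go" suffix
		MaxRequests: 30,             // API-call budget before giving up
	}
	return client.SearchFile(search)
}
```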
--------------------------------------------------------------------------------
/neuralnet/data_set.go:
--------------------------------------------------------------------------------
1 | package neuralnet
2 |
3 | import (
4 | "math"
5 | "math/rand"
6 | "sort"
7 |
8 | "github.com/unixpickle/num-analysis/kahan"
9 | "github.com/unixpickle/whichlang/tokens"
10 | )
11 |
12 | const ValidationFraction = 0.3
13 |
14 | // A DataSet is a set of data split into training
15 | // samples and validation samples.
16 | type DataSet struct {
17 | ValidationSamples map[string][]tokens.Freqs
18 | TrainingSamples map[string][]tokens.Freqs
19 | NormalTrainingSamples map[string][][]float64
20 |
21 | // These are statistical properties of the
22 | // training samples' frequency values.
23 | MeanFrequency float64
24 | FrequencyStddev float64
25 | }
26 |
27 | // NewDataSet creates a DataSet by randomly
28 | // partitioning some data samples into
29 | // validation and training samples.
30 | func NewDataSet(samples map[string][]tokens.Freqs) *DataSet {
31 | res := &DataSet{
32 | ValidationSamples: map[string][]tokens.Freqs{},
33 | TrainingSamples: map[string][]tokens.Freqs{},
34 | }
35 | for lang, langSamples := range samples {
36 | shuffled := make([]tokens.Freqs, len(langSamples))
37 | perm := rand.Perm(len(shuffled))
38 | for i, x := range perm {
39 | shuffled[i] = langSamples[x]
40 | }
41 |
42 | numValid := int(float64(len(langSamples)) * ValidationFraction)
43 | res.ValidationSamples[lang] = shuffled[:numValid]
44 | res.TrainingSamples[lang] = shuffled[numValid:]
45 | }
46 |
47 | res.computeStatistics()
48 | res.computeNormalSamples()
49 |
50 | return res
51 | }
52 |
53 | // CrossScore returns the fraction of withheld
54 | // samples that the Network classifies correctly.
55 | func (c *DataSet) CrossScore(n *Network) float64 {
56 | return scoreNetwork(n, c.ValidationSamples)
57 | }
58 |
59 | // TrainingScore returns the fraction of
60 | // training samples that the Network classifies correctly.
61 | func (c *DataSet) TrainingScore(n *Network) float64 {
62 | return scoreNetwork(n, c.TrainingSamples)
63 | }
64 |
65 | // Tokens returns all of the tokens from all
66 | // of the training samples.
67 | func (c *DataSet) Tokens() []string {
68 | toks := map[string]bool{}
69 | for _, samples := range c.TrainingSamples {
70 | for _, sample := range samples {
71 | for tok := range sample {
72 | toks[tok] = true
73 | }
74 | }
75 | }
76 |
77 | res := make([]string, 0, len(toks))
78 | for tok := range toks {
79 | res = append(res, tok)
80 | }
81 | sort.Strings(res)
82 | return res
83 | }
84 |
85 | // Langs returns all of the languages represented
86 | // by the training samples.
87 | func (c *DataSet) Langs() []string {
88 | res := make([]string, 0, len(c.TrainingSamples))
89 | for lang := range c.TrainingSamples {
90 | res = append(res, lang)
91 | }
92 | sort.Strings(res)
93 | return res
94 | }
95 |
96 | func (c *DataSet) computeStatistics() {
97 | tokens := c.Tokens()
98 |
99 | freqSum := kahan.NewSummer64()
100 | freqCount := 0
101 | for _, langSamples := range c.TrainingSamples {
102 | for _, sample := range langSamples {
103 | freqCount += len(tokens)
104 | for _, freq := range sample {
105 | freqSum.Add(freq)
106 | }
107 | }
108 | }
109 |
110 | c.MeanFrequency = freqSum.Sum() / float64(freqCount)
111 |
112 | variationSum := kahan.NewSummer64()
113 | for _, langSamples := range c.TrainingSamples {
114 | for _, sample := range langSamples {
115 | for _, token := range tokens {
116 | freq := sample[token]
117 | variationSum.Add(math.Pow(freq-c.MeanFrequency, 2))
118 | }
119 | }
120 | }
121 |
122 | c.FrequencyStddev = math.Sqrt(variationSum.Sum() / float64(freqCount))
123 | }
124 |
125 | func (c *DataSet) computeNormalSamples() {
126 | c.NormalTrainingSamples = map[string][][]float64{}
127 | tokens := c.Tokens()
128 |
129 | for lang, langSamples := range c.TrainingSamples {
130 | sampleList := make([][]float64, len(langSamples))
131 | for i, sample := range langSamples {
132 | sampleVec := make([]float64, len(tokens))
133 | for j, token := range tokens {
134 | sampleVec[j] = (sample[token] - c.MeanFrequency) / c.FrequencyStddev
135 | }
136 | sampleList[i] = sampleVec
137 | }
138 | c.NormalTrainingSamples[lang] = sampleList
139 | }
140 | }
141 |
142 | func scoreNetwork(n *Network, samples map[string][]tokens.Freqs) float64 {
143 | var totalRight int
144 | var total int
145 | for lang, langSamples := range samples {
146 | for _, sample := range langSamples {
147 | if n.Classify(sample) == lang {
148 | totalRight++
149 | }
150 | total++
151 | }
152 | }
153 | return float64(totalRight) / float64(total)
154 | }
155 |
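`computeNormalSamples` above stores each training sample as a z-scored vector: subtract the corpus-wide mean frequency, then divide by the standard deviation. A tiny sketch of that transformation with made-up numbers:

```go
package main

import "fmt"

func main() {
	mean, stddev := 0.04, 0.12 // hypothetical corpus statistics
	freqs := []float64{0.0, 0.04, 0.28}

	normalized := make([]float64, len(freqs))
	for i, f := range freqs {
		// Same formula as computeNormalSamples.
		normalized[i] = (f - mean) / stddev
	}
	fmt.Println(normalized) // roughly [-0.333 0 2]
}
```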
--------------------------------------------------------------------------------
/svm/trainer_params.go:
--------------------------------------------------------------------------------
1 | package svm
2 |
3 | import (
4 | "errors"
5 | "os"
6 | "strconv"
7 | )
8 |
9 | const (
10 | defaultTradeoff = 1e-5
11 | defaultCrossValidationFraction = 0.3
12 | )
13 |
14 | var (
15 | defaultRBFParams = [][]float64{{1e-5}, {1e-4}, {1e-3}, {1e-2}, {1e-1}, {1e0}, {1e1}, {1e2}}
16 | defaultPolyPowers = []float64{2}
17 | defaultPolySums = []float64{0, 1}
18 | )
19 |
20 | // These environment variables specify
21 | // various parameters for the SVM trainer.
22 | const (
23 | // Set this to "1" to get verbose logs.
24 | VerboseEnvVar = "SVM_VERBOSE"
25 |
26 | // You may set this to "linear", "rbf", or
27 | // "polynomial".
28 | KernelEnvVar = "SVM_KERNEL"
29 |
30 | // The numerical constant used in the
31 | // RBF kernel.
32 | RBFParamEnvVar = "SVM_RBF_PARAM"
33 |
34 | // The degree parameter for polynomial kernels.
35 | PolyDegreeEnvVar = "SVM_POLY_DEGREE"
36 |
37 | // The summed term (before applying the exponential)
38 | // for polynomial kernels.
39 | PolySumEnvVar = "SVM_POLY_SUM"
40 |
41 | // The tradeoff between margin size and hinge loss.
42 | // The higher the tradeoff value, the greater the
43 | // margin size, but at the expense of correct
44 | // classifications.
45 | TradeoffEnvVar = "SVM_TRADEOFF"
46 |
47 | // The fraction (from 0-1) of samples which are
48 | // used for cross validation.
49 | CrossValidationEnvVar = "SVM_CROSS_VALIDATION"
50 | )
51 |
52 | // TrainerParams specifies parameters for the
53 | // SVM trainer.
54 | type TrainerParams struct {
55 | Verbose bool
56 | Kernels []*Kernel
57 | Tradeoff float64
58 |
59 | CrossValidation float64
60 | }
61 |
62 | // EnvTrainerParams generates TrainerParams
63 | // by reading environment variables.
64 | // If an environment variable is incorrectly
65 | // formatted, this returns an error.
66 | // When a variable is missing, a default value
67 | // or set of values will be used.
68 | func EnvTrainerParams() (*TrainerParams, error) {
69 | var res TrainerParams
70 | var err error
71 |
72 | if res.Tradeoff, err = envTradeoff(); err != nil {
73 | return nil, err
74 | }
75 | if res.CrossValidation, err = envCrossValidation(); err != nil {
76 | return nil, err
77 | }
78 | res.Verbose = (os.Getenv(VerboseEnvVar) == "1")
79 |
80 | kernTypes, err := envKernelTypes()
81 | if err != nil {
82 | return nil, err
83 | }
84 |
85 | for _, kernType := range kernTypes {
86 | params, err := envKernelParams(kernType)
87 | if err != nil {
88 | return nil, err
89 | }
90 | for _, param := range params {
91 | kernel := &Kernel{
92 | Type: kernType,
93 | Params: param,
94 | }
95 | res.Kernels = append(res.Kernels, kernel)
96 | }
97 | }
98 |
99 | return &res, nil
100 | }
101 |
102 | func envTradeoff() (float64, error) {
103 | if val := os.Getenv(TradeoffEnvVar); val != "" {
104 | return strconv.ParseFloat(val, 64)
105 | } else {
106 | return defaultTradeoff, nil
107 | }
108 | }
109 |
110 | func envCrossValidation() (float64, error) {
111 | if val := os.Getenv(CrossValidationEnvVar); val != "" {
112 | return strconv.ParseFloat(val, 64)
113 | } else {
114 | return defaultCrossValidationFraction, nil
115 | }
116 | }
117 |
118 | func envKernelTypes() ([]KernelType, error) {
119 | if val := os.Getenv(KernelEnvVar); val != "" {
120 | res, ok := map[string]KernelType{
121 | "linear": LinearKernel,
122 | "polynomial": PolynomialKernel,
123 | "rbf": RadialBasisKernel,
124 | }[val]
125 | if !ok {
126 | return nil, errors.New("unknown kernel: " + val)
127 | } else {
128 | return []KernelType{res}, nil
129 | }
130 | } else {
131 | return []KernelType{LinearKernel, PolynomialKernel, RadialBasisKernel}, nil
132 | }
133 | }
134 |
135 | func envKernelParams(t KernelType) ([][]float64, error) {
136 | switch t {
137 | case LinearKernel:
138 | return [][]float64{{}}, nil
139 | case RadialBasisKernel:
140 | if val := os.Getenv(RBFParamEnvVar); val != "" {
141 | res, err := strconv.ParseFloat(val, 64)
142 | if err != nil {
143 | return nil, errors.New("invalid RBF param: " + val)
144 | }
145 | return [][]float64{{res}}, nil
146 | } else {
147 | return defaultRBFParams, nil
148 | }
149 | case PolynomialKernel:
150 | powers := defaultPolyPowers
151 | sums := defaultPolySums
152 | if val := os.Getenv(PolySumEnvVar); val != "" {
153 | sum, err := strconv.ParseFloat(val, 64)
154 | if err != nil {
155 | return nil, errors.New("invalid poly sum: " + val)
156 | }
157 | sums = []float64{sum}
158 | }
159 | if val := os.Getenv(PolyDegreeEnvVar); val != "" {
160 | degree, err := strconv.ParseFloat(val, 64)
161 | if err != nil {
162 | return nil, errors.New("invalid poly degree: " + val)
163 | }
164 | powers = []float64{degree}
165 | }
166 | res := make([][]float64, 0, len(powers)*len(sums))
167 | for _, power := range powers {
168 | for _, sum := range sums {
169 | res = append(res, []float64{sum, power})
170 | }
171 | }
172 | return res, nil
173 | default:
174 | panic("unknown kernel: " + strconv.Itoa(int(t)))
175 | }
176 | }
177 |
--------------------------------------------------------------------------------
/neuralnet/train.go:
--------------------------------------------------------------------------------
1 | package neuralnet
2 |
3 | import (
4 | "log"
5 | "math/rand"
6 |
7 | "github.com/unixpickle/whichlang/tokens"
8 | )
9 |
10 | const InitialIterationCount = 200
11 |
12 | func Train(data map[string][]tokens.Freqs) *Network {
13 | ds := NewDataSet(data)
14 |
15 | var best *Network
16 | var bestCrossScore float64
17 | var bestTrainScore float64
18 |
19 | verbose := verboseFlag()
20 |
21 | for _, stepSize := range stepSizes() {
22 | if verbose {
23 | log.Printf("trying step size %f", stepSize)
24 | }
25 |
26 | t := NewTrainer(ds, stepSize, verbose)
27 | t.Train(maxIterations())
28 |
29 | n := t.Network()
30 | if n.containsNaN() {
31 | if verbose {
32 | log.Printf("got NaN for step size %f", stepSize)
33 | }
34 | continue
35 | }
36 | crossScore := ds.CrossScore(n)
37 | trainScore := ds.TrainingScore(n)
38 | if verbose {
39 | log.Printf("stepSize=%f crossScore=%f trainScore=%f", stepSize,
40 | crossScore, trainScore)
41 | }
42 | if (crossScore == bestCrossScore && trainScore >= bestTrainScore) ||
43 | best == nil || (crossScore > bestCrossScore) {
44 | bestCrossScore = crossScore
45 | bestTrainScore = trainScore
46 | best = n
47 | }
48 | }
49 |
50 | return best
51 | }
52 |
53 | type Trainer struct {
54 | n *Network
55 | d *DataSet
56 | g *gradientCalc
57 |
58 | stepSize float64
59 | verbose bool
60 | }
61 |
62 | func NewTrainer(d *DataSet, stepSize float64, verbose bool) *Trainer {
63 | hiddenCount := hiddenSize(len(d.TrainingSamples))
64 | n := &Network{
65 | Tokens: d.Tokens(),
66 | Langs: d.Langs(),
67 | HiddenWeights: make([][]float64, hiddenCount),
68 | OutputWeights: make([][]float64, len(d.TrainingSamples)),
69 | InputShift: -d.MeanFrequency,
70 | InputScale: 1 / d.FrequencyStddev,
71 | }
72 | for i := range n.OutputWeights {
73 | n.OutputWeights[i] = make([]float64, hiddenCount+1)
74 | for j := range n.OutputWeights[i] {
75 | n.OutputWeights[i][j] = rand.Float64()*2 - 1
76 | }
77 | }
78 | for i := range n.HiddenWeights {
79 | n.HiddenWeights[i] = make([]float64, len(n.Tokens)+1)
80 | for j := range n.HiddenWeights[i] {
81 | n.HiddenWeights[i][j] = rand.Float64()*2 - 1
82 | }
83 | }
84 | return &Trainer{
85 | n: n,
86 | d: d,
87 | g: newGradientCalc(n),
88 | stepSize: stepSize,
89 | verbose: verbose,
90 | }
91 | }
92 |
93 | func (t *Trainer) Train(maxIters int) {
94 | iters := InitialIterationCount
95 | if iters > maxIters {
96 | iters = maxIters
97 | }
98 | for i := 0; i < iters; i++ {
99 | if verboseStepsFlag() {
100 | log.Printf("done %d iterations, cross=%f training=%f",
101 | i, t.d.CrossScore(t.n), t.d.TrainingScore(t.n))
102 | }
103 | t.runAllSamples()
104 | }
105 | if iters == maxIters {
106 | return
107 | }
108 |
109 | if t.n.containsNaN() {
110 | return
111 | }
112 |
113 | // Use cross-validation to find the best
114 | // number of iterations.
115 | crossScore := t.d.CrossScore(t.n)
116 | trainScore := t.d.TrainingScore(t.n)
117 | lastNet := t.n.Copy()
118 |
119 | for {
120 | if t.verbose {
121 | log.Printf("current scores: cross=%f train=%f iters=%d",
122 | crossScore, trainScore, iters)
123 | }
124 |
125 | nextAmount := iters
126 | if nextAmount+iters > maxIters {
127 | nextAmount = maxIters - iters
128 | }
129 | for i := 0; i < nextAmount; i++ {
130 | if verboseStepsFlag() {
131 | log.Printf("done %d iterations, cross=%f training=%f",
132 | i+iters, t.d.CrossScore(t.n), t.d.TrainingScore(t.n))
133 | }
134 | t.runAllSamples()
135 | if t.n.containsNaN() {
136 | break
137 | }
138 | }
139 | iters += nextAmount
140 |
141 | if t.n.containsNaN() {
142 | t.n = lastNet
143 | break
144 | }
145 |
146 | newCrossScore := t.d.CrossScore(t.n)
147 | newTrainScore := t.d.TrainingScore(t.n)
148 | if (newCrossScore == crossScore && newTrainScore == trainScore) ||
149 | newCrossScore < crossScore {
150 | t.n = lastNet
151 | return
152 | }
153 |
154 | crossScore = newCrossScore
155 | trainScore = newTrainScore
156 |
157 | if iters == maxIters {
158 | return
159 | }
160 | lastNet = t.n.Copy()
161 | }
162 | }
163 |
164 | func (t *Trainer) Network() *Network {
165 | return t.n
166 | }
167 |
168 | func (t *Trainer) runAllSamples() {
169 | var samples []struct {
170 | LangIdx int
171 | Sample []float64
172 | }
173 |
174 | for i, lang := range t.n.Langs {
175 | var sample struct {
176 | LangIdx int
177 | Sample []float64
178 | }
179 | sample.LangIdx = i
180 |
181 | trainingSamples := t.d.NormalTrainingSamples[lang]
182 | for _, s := range trainingSamples {
183 | sample.Sample = s
184 | samples = append(samples, sample)
185 | }
186 | }
187 |
188 | perm := rand.Perm(len(samples))
189 | for _, i := range perm {
190 | t.descendSample(samples[i].Sample, samples[i].LangIdx)
191 | }
192 | }
193 |
194 | // descendSample performs gradient descent to
195 | // reduce the output error for a given sample.
196 | func (t *Trainer) descendSample(inputs []float64, langIdx int) {
197 | t.g.Compute(inputs, langIdx)
198 |
199 | for i, partials := range t.g.HiddenPartials {
200 | for j, partial := range partials {
201 | t.n.HiddenWeights[i][j] -= partial * t.stepSize
202 | }
203 | }
204 | for i, partials := range t.g.OutputPartials {
205 | for j, partial := range partials {
206 | t.n.OutputWeights[i][j] -= partial * t.stepSize
207 | }
208 | }
209 | }
210 |
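Putting the neuralnet pieces together, training programmatically rather than through `cmd/trainer` looks roughly like the sketch below. The paths and ubiquity value are placeholders, and `Train` picks up its hyper-parameters from the `NEURALNET_*` environment variables described in the README:

```go
package main

import (
	"io/ioutil"
	"log"

	"github.com/unixpickle/whichlang/neuralnet"
	"github.com/unixpickle/whichlang/tokens"
)

func main() {
	counts, err := tokens.ReadSampleCounts("/path/to/samples")
	if err != nil {
		log.Fatal(err)
	}
	counts.Prune(15) // ubiquity: drop rare, file-specific tokens

	net := neuralnet.Train(counts.SampleFreqs())
	if err := ioutil.WriteFile("/path/to/classifier.json", net.Encode(), 0644); err != nil {
		log.Fatal(err)
	}
}
```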
--------------------------------------------------------------------------------
/idtree/train.go:
--------------------------------------------------------------------------------
1 | package idtree
2 |
3 | import (
4 | "math"
5 | "runtime"
6 | "sort"
7 |
8 | "github.com/unixpickle/whichlang/tokens"
9 | )
10 |
11 | type splitInfo struct {
12 | TokenIdx int
13 | Threshold float64
14 | Entropy float64
15 | }
16 |
17 | // Train returns a *Classifier which is the
18 | // result of running ID3 on a set of training
19 | // samples.
20 | func Train(freqs map[string][]tokens.Freqs) *Classifier {
21 | toks := allTokens(freqs)
22 | samples := freqsToLinearSamples(toks, freqs)
23 | return generateClassifier(toks, samples)
24 | }
25 |
26 | func allTokens(freqs map[string][]tokens.Freqs) []string {
27 | words := make([]string, 0)
28 | seenWords := map[string]bool{}
29 | for _, freqsList := range freqs {
30 | for _, freqs := range freqsList {
31 | for word := range freqs {
32 | if !seenWords[word] {
33 | seenWords[word] = true
34 | words = append(words, word)
35 | }
36 | }
37 | }
38 | }
39 | return words
40 | }
41 |
42 | // generateClassifier generates a classifier
43 | // for the given set of samples.
44 | func generateClassifier(toks []string, s []linearSample) *Classifier {
45 | tokIdx, thresh := bestDecision(s)
46 | if tokIdx == -1 {
47 | lang := languageMajority(s)
48 | return &Classifier{
49 | LeafClassification: &lang,
50 | }
51 | }
52 | res := &Classifier{
53 | Keyword: toks[tokIdx],
54 | Threshold: thresh,
55 | }
56 | f, t := splitData(s, tokIdx, thresh)
57 | res.FalseBranch = generateClassifier(toks, f)
58 | res.TrueBranch = generateClassifier(toks, t)
59 | return res
60 | }
61 |
62 | func splitData(s []linearSample, tokIdx int, thresh float64) (f, t []linearSample) {
63 | f = make([]linearSample, 0, len(s))
64 | t = make([]linearSample, 0, len(s))
65 |
66 | for _, sample := range s {
67 | if sample.freqs[tokIdx] > thresh {
68 | t = append(t, sample)
69 | } else {
70 | f = append(f, sample)
71 | }
72 | }
73 |
74 | return
75 | }
76 |
77 | // bestDecision returns the token and threshold
78 | // which split the samples optimally (by the
79 | // criterion of entropy).
80 | // If no split exists, this returns (-1, -1).
81 | func bestDecision(s []linearSample) (tokIdx int, thresh float64) {
82 | if len(s) == 0 {
83 | return -1, -1
84 | }
85 |
86 | maxProcs := runtime.GOMAXPROCS(0)
87 | tokenCount := len(s[0].freqs)
88 |
89 | toksPerGo := tokenCount / maxProcs
90 | splitChan := make(chan *splitInfo, maxProcs)
91 | for i := 0; i < maxProcs; i++ {
92 | tokCount := toksPerGo
93 | tokStart := toksPerGo * i
94 |
95 | // The last set might need to be slightly larger
96 | // due to division truncation.
97 | if i == maxProcs-1 {
98 | tokCount = tokenCount - tokStart
99 | }
100 |
101 | go bestNodeSubset(tokStart, tokCount, s, splitChan)
102 | }
103 |
104 | var best *splitInfo
105 | for i := 0; i < maxProcs; i++ {
106 | res := <-splitChan
107 | if res == nil {
108 | continue
109 | }
110 | if best == nil || res.Entropy < best.Entropy {
111 | best = res
112 | }
113 | }
114 |
115 | if best == nil {
116 | return -1, -1
117 | }
118 |
119 | return best.TokenIdx, best.Threshold
120 | }
121 |
122 | func bestNodeSubset(startIdx, count int, s []linearSample, res chan<- *splitInfo) {
123 | bestThresh := -1.0
124 | var bestEntropy float64
125 | var bestIdx int
126 | for i := 0; i < count; i++ {
127 | idx := startIdx + i
128 | thresh, entropy := bestSplit(s, idx)
129 | if thresh < 0 {
130 | continue
131 | } else if bestThresh < 0 || entropy < bestEntropy {
132 | bestEntropy = entropy
133 | bestThresh = thresh
134 | bestIdx = idx
135 | }
136 | }
137 | if bestThresh == -1 {
138 | res <- nil
139 | } else {
140 | res <- &splitInfo{bestIdx, bestThresh, bestEntropy}
141 | }
142 | }
143 |
144 | // bestSplit finds the ideal threshold for splitting
145 | // samples by a given token (specified by an index).
146 | // This returns the threshold and the resulting entropy.
147 | // The threshold will be -1 if no split is useful.
148 | func bestSplit(unsorted []linearSample, tokenIdx int) (thresh float64, entrop float64) {
149 | samples := make([]linearSample, len(unsorted))
150 | copy(samples, unsorted)
151 | sorter := &sampleSorter{samples, tokenIdx}
152 | sort.Sort(sorter)
153 |
154 | lowerDistribution := map[string]int{}
155 | upperDistribution := map[string]int{}
156 |
157 | for _, sample := range samples {
158 | upperDistribution[sample.lang]++
159 | }
160 |
161 | if len(upperDistribution) == 1 {
162 | // Can't split homogeneous data effectively.
163 | return -1, -1
164 | }
165 |
166 | thresh = -1
167 | entrop = -1
168 |
169 | if len(samples) == 0 {
170 | return
171 | }
172 |
173 | lastFreq := samples[0].freqs[tokenIdx]
174 | for i := 1; i < len(samples); i++ {
175 | upperDistribution[samples[i-1].lang]--
176 | lowerDistribution[samples[i-1].lang]++
177 |
178 | freq := samples[i].freqs[tokenIdx]
179 | if freq == lastFreq {
180 | continue
181 | }
182 |
183 | upperFrac := float64(len(samples)-i) / float64(len(samples))
184 | lowerFrac := float64(i) / float64(len(samples))
185 | disorder := upperFrac*distributionEntropy(upperDistribution) +
186 | lowerFrac*distributionEntropy(lowerDistribution)
187 | if disorder < entrop || thresh == -1 {
188 | entrop = disorder
189 | thresh = (lastFreq + freq) / 2
190 | }
191 |
192 | lastFreq = freq
193 | }
194 |
195 | return
196 | }
197 |
198 | func distributionEntropy(dist map[string]int) float64 {
199 | var res float64
200 | var totalCount int
201 | for _, count := range dist {
202 | totalCount += count
203 | }
204 | for _, count := range dist {
205 | fraction := float64(count) / float64(totalCount)
206 | if fraction != 0 {
207 | res -= math.Log(fraction) * fraction
208 | }
209 | }
210 | return res
211 | }
212 |
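`bestSplit` above scores each candidate threshold by the size-weighted entropy of the two resulting sides, and `distributionEntropy` is ordinary Shannon entropy with natural logarithms, H = -Σ p·ln(p). A standalone sketch of the quantity being minimized:

```go
package main

import (
	"fmt"
	"math"
)

// entropy mirrors distributionEntropy above: H = -Σ p·ln(p) over the
// label distribution, where p is each label's fraction of the total.
func entropy(counts []int) float64 {
	total := 0
	for _, c := range counts {
		total += c
	}
	var h float64
	for _, c := range counts {
		p := float64(c) / float64(total)
		if p != 0 {
			h -= p * math.Log(p)
		}
	}
	return h
}

func main() {
	fmt.Println(entropy([]int{5, 5})) // ln 2 ≈ 0.6931: maximally mixed
	fmt.Println(entropy([]int{10}))   // 0: homogeneous, nothing to split
}
```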
--------------------------------------------------------------------------------
/svm/trainer.go:
--------------------------------------------------------------------------------
1 | package svm
2 |
3 | import (
4 | "log"
5 | "math/rand"
6 | "time"
7 |
8 | "github.com/unixpickle/num-analysis/linalg"
9 | "github.com/unixpickle/weakai/svm"
10 | "github.com/unixpickle/whichlang/tokens"
11 | )
12 |
13 | const farAwayTimeout = time.Hour * 24 * 365
14 |
15 | func Train(data map[string][]tokens.Freqs) *Classifier {
16 | params, err := EnvTrainerParams()
17 | if err != nil {
18 | panic(err)
19 | }
20 | return TrainParams(data, params)
21 | }
22 |
23 | func TrainParams(data map[string][]tokens.Freqs, p *TrainerParams) *Classifier {
24 | crossFreqs, trainingFreqs := partitionSamples(data, p.CrossValidation)
25 | tokens, samples := vectorizeSamples(trainingFreqs)
26 |
27 | solver := svm.GradientDescentSolver{
28 | Timeout: farAwayTimeout,
29 | Tradeoff: p.Tradeoff,
30 | }
31 |
32 | var bestClassifier *Classifier
33 | var bestValidationScore float64
34 |
35 | for _, kernel := range p.Kernels {
36 | if p.Verbose {
37 | log.Println("Trying kernel:", kernel)
38 | }
39 | solverKernel := cachedKernel(kernel)
40 | classifier := &Classifier{
41 | Keywords: tokens,
42 | Kernel: kernel,
43 | Classifiers: map[string]BinaryClassifier{},
44 | }
45 |
46 | usedSamples := map[int]linalg.Vector{}
47 | for lang := range samples {
48 | if p.Verbose {
49 | log.Println("Training classifier for language:", lang)
50 | }
51 | problem := svmProblem(samples, lang, solverKernel)
52 | solution := solver.Solve(problem)
53 | binClass := BinaryClassifier{
54 | SupportVectors: make([]int, len(solution.SupportVectors)),
55 | Weights: make([]float64, len(solution.Coefficients)),
56 | Threshold: -solution.Threshold,
57 | }
58 | copy(binClass.Weights, solution.Coefficients)
59 | for i, v := range solution.SupportVectors {
60 | // v.UserInfo will be turned into a support
61 | // vector index by makeSampleVectorList().
62 | binClass.SupportVectors[i] = v.UserInfo
63 | usedSamples[v.UserInfo] = linalg.Vector(v.V)
64 | }
65 | classifier.Classifiers[lang] = binClass
66 | }
67 |
68 | makeSampleVectorList(classifier, usedSamples)
69 |
70 | score := correctFraction(classifier, crossFreqs)
71 | if p.Verbose {
72 | trainingScore := correctFraction(classifier, trainingFreqs)
73 | log.Printf("Results: cross=%f training=%f support=%d/%d", score,
74 | trainingScore, len(classifier.SampleVectors), countSamples(samples))
75 | }
76 | if score > bestValidationScore || bestClassifier == nil {
77 | bestClassifier, bestValidationScore = classifier, score
78 | }
79 | }
80 |
81 | return bestClassifier
82 | }
83 |
84 | func partitionSamples(data map[string][]tokens.Freqs, crossFrac float64) (cross,
85 | training map[string][]tokens.Freqs) {
86 |
87 | cross = map[string][]tokens.Freqs{}
88 | training = map[string][]tokens.Freqs{}
89 |
90 | for lang, samples := range data {
91 | p := rand.Perm(len(samples))
92 | newSamples := make([]tokens.Freqs, len(samples))
93 | for i, x := range p {
94 | newSamples[i] = samples[x]
95 | }
96 | crossCount := int(crossFrac * float64(len(samples)))
97 | cross[lang] = newSamples[:crossCount]
98 | training[lang] = newSamples[crossCount:]
99 | }
100 |
101 | return
102 | }
103 |
104 | func vectorizeSamples(data map[string][]tokens.Freqs) ([]string, map[string][]svm.Sample) {
105 | seenToks := map[string]bool{}
106 | for _, samples := range data {
107 | for _, sample := range samples {
108 | for tok := range sample {
109 | seenToks[tok] = true
110 | }
111 | }
112 | }
113 | toks := make([]string, 0, len(seenToks))
114 | for tok := range seenToks {
115 | toks = append(toks, tok)
116 | }
117 |
118 | sampleMap := map[string][]svm.Sample{}
119 | sampleID := 1
120 | for lang, samples := range data {
121 | vecSamples := make([]svm.Sample, 0, len(samples))
122 | for _, sample := range samples {
123 | vec := make([]float64, len(toks))
124 | for i, tok := range toks {
125 | vec[i] = sample[tok]
126 | }
127 | svmSample := svm.Sample{
128 | V: vec,
129 | UserInfo: sampleID,
130 | }
131 | sampleID++
132 | vecSamples = append(vecSamples, svmSample)
133 | }
134 | sampleMap[lang] = vecSamples
135 | }
136 |
137 | return toks, sampleMap
138 | }
139 |
140 | func countSamples(s map[string][]svm.Sample) int {
141 | var count int
142 | for _, samples := range s {
143 | count += len(samples)
144 | }
145 | return count
146 | }
147 |
148 | func cachedKernel(k *Kernel) svm.Kernel {
149 | return svm.CachedKernel(func(s1, s2 svm.Sample) float64 {
150 | return k.Product(linalg.Vector(s1.V), linalg.Vector(s2.V))
151 | })
152 | }
153 |
154 | func svmProblem(data map[string][]svm.Sample, posLang string, k svm.Kernel) *svm.Problem {
155 | var positives, negatives []svm.Sample
156 | for lang, samples := range data {
157 | if lang == posLang {
158 | positives = append(positives, samples...)
159 | } else {
160 | negatives = append(negatives, samples...)
161 | }
162 | }
163 | return &svm.Problem{
164 | Positives: positives,
165 | Negatives: negatives,
166 | Kernel: k,
167 | }
168 | }
169 |
170 | func correctFraction(c *Classifier, data map[string][]tokens.Freqs) float64 {
171 | var correct, total int
172 | for lang, samples := range data {
173 | for _, sample := range samples {
174 | total++
175 | if c.Classify(sample) == lang {
176 | correct++
177 | }
178 | }
179 | }
180 | return float64(correct) / float64(total)
181 | }
182 |
183 | func makeSampleVectorList(c *Classifier, used map[int]linalg.Vector) {
184 | userInfoToVecIdx := map[int]int{}
185 |
186 | for userInfo, sample := range used {
187 | userInfoToVecIdx[userInfo] = len(c.SampleVectors)
188 | c.SampleVectors = append(c.SampleVectors, sample)
189 | }
190 |
191 | for _, binClass := range c.Classifiers {
192 | for i, userInfo := range binClass.SupportVectors {
193 | binClass.SupportVectors[i] = userInfoToVecIdx[userInfo]
194 | }
195 | }
196 | }
197 |
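`TrainParams` above sets up a one-vs-rest scheme: `svmProblem` treats one language's samples as positives and every other language's as negatives, yielding one `BinaryClassifier` per language. The actual decision rule lives in this package's classifier code, which is not shown here; a plausible sketch of how per-language binary scores would be combined (the score values are made up):

```go
package main

import "fmt"

// pickLanguage takes the language whose binary classifier produced the
// highest score. This is the usual one-vs-rest decision rule; the real
// scores would come from kernel evaluations against support vectors.
func pickLanguage(scores map[string]float64) string {
	first := true
	var best string
	var bestScore float64
	for lang, s := range scores {
		if first || s > bestScore {
			best, bestScore = lang, s
			first = false
		}
	}
	return best
}

func main() {
	fmt.Println(pickLanguage(map[string]float64{"Go": 1.3, "Ruby": -0.2, "C": 0.7}))
}
```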
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # whichlang
2 |
3 | This is a suite of Machine Learning tools for identifying the language in which a piece of code is written. It could potentially be used for text editors, code hosting websites, and much more.
4 |
5 | A seasoned programmer could quickly tell you that this program is written in C:
6 |
7 | ```c
8 | #include <stdio.h>
9 |
10 | int main(int argc, const char ** argv) {
11 | printf("Hello, world!");
12 | }
13 | ```
14 |
15 | The goal of `whichlang` is to teach a program to do the same. By showing a Machine Learning algorithm a ton of code, you can teach it to *learn* to identify programming languages on its own.
16 |
17 | # Usage
18 |
19 | There are four steps to using whichlang:
20 |
21 | * Configure Go and download whichlang.
22 | * Fetch code samples from Github or some other source.
23 | * Train a classifier with the code samples.
24 | * Use the whichlang API or server with the classifier you trained.
25 |
26 | ## Configuring Go and whichlang
27 |
28 | First, follow the instructions on [this page](https://golang.org/doc/install) to setup Go. Once Go is setup and you have a `GOPATH` configured, run this set of commands:
29 |
30 | ```
31 | $ go get github.com/unixpickle/whichlang
32 | $ cd $GOPATH/src/github.com/unixpickle/whichlang
33 | ```
34 |
35 | Now you have downloaded `whichlang` and are sitting in its root source folder.
36 |
37 | ## Fetching samples
38 |
39 | To fetch samples from Github, you must have a Github account (having more than one Github account may be beneficial, as well). You should decide how many samples you want for each programming language. I have found that 180 is more than enough.
40 |
41 | You can fetch samples and save them to a directory as follows:
42 |
43 | ```
44 | $ mkdir /path/to/samples
45 | $ go run cmd/fetchlang/*.go /path/to/samples 180
46 | ```
47 |
48 | In the above example, I specified 180 samples per language. This command will prompt you for your Github credentials (to get around strict API rate limits). If you request a large number of samples (where 180 counts as a large number), you may hit Github's API rate limits several times during the fetching process. If this happens, delete the partially-downloaded source directories (they will be subdirectories of your sample directory and will contain fewer than 180 samples), then wait an hour before re-running `fetchlang`. The `fetchlang` sub-command automatically skips any source directories that are already present, making it relatively easy to resume paused or rate-limited downloads.
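After fetching completes, the sample directory contains one subdirectory per language, each holding plain source files; this is the layout the training tools (see `tokens.ReadSampleCounts`) expect. Roughly, with illustrative names:

```
/path/to/samples/
  Go/
    sample0.txt
    sample1.txt
    ...
  Ruby/
    sample0.txt
    ...
```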
49 |
50 | ## Training a classifier
51 |
52 | With whichlang, you can train a number of different kinds of classifiers on your data. Currently, you can use the following classifiers:
53 |
54 | * [ID3](https://en.wikipedia.org/wiki/ID3)
55 | * [K-nearest neighbors](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
56 | * [Artificial Neural Networks](https://en.wikipedia.org/wiki/Artificial_neural_network)
57 | * [Support Vector Machines](https://en.wikipedia.org/wiki/Support_vector_machine)
58 |
59 | Out of these algorithms, I have found that Support Vector Machines are the simplest to train and work very well. Artificial Neural Networks are a close second, but they have more hyper-parameters and are thus harder to tune well. In this document, I will describe how to train both of these classifiers, leaving out ID3 and K-nearest neighbors.
60 |
61 | ### Choosing the "ubiquity"
62 |
63 | For any classifier you use, you must choose a "ubiquity" value. Since whichlang works by extracting keywords from source files, it is important to distinguish potentially meaningful keywords from file-specific ones like variable names or embedded strings. To do this, keywords which appear in fewer than `N` files are ignored during training and classification, where `N` is the "ubiquity". I have found that a ubiquity of 10-20 works well when you have roughly 100 source files.
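Under the hood, this pruning corresponds to `SampleCounts.Prune` in the `tokens` package. A minimal sketch of its effect (the samples path is a placeholder):

```go
package main

import (
	"fmt"
	"log"

	"github.com/unixpickle/whichlang/tokens"
)

func main() {
	counts, err := tokens.ReadSampleCounts("/path/to/samples")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("unique tokens before pruning:", counts.NumTokens())

	// A ubiquity of 15: tokens seen in 15 or fewer files are dropped.
	counts.Prune(15)
	fmt.Println("unique tokens after pruning:", counts.NumTokens())
}
```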
64 |
65 | ### Support Vector Machines
66 |
67 | The easiest way to train a Support Vector Machine is to allow whichlang to select all the hyper-parameters for you. Note, however, that this option is *very* slow, so you may want to keep reading.
68 |
69 | ```
70 | $ go run cmd/trainer/*.go svm 15 /path/to/samples /path/to/classifier.json
71 | ```
72 |
73 | In the above command, I specified a ubiquity of 15 files. This will train an SVM on the given sample directory, outputting a classifier to `/path/to/classifier.json`. As this command runs, it will go through many different possible SVM configurations, choosing the one which performs best on new samples (as measured via cross-validation). Since it has to try many possible configurations, it will take a long time to run (perhaps hours or even days). I have already gone through the trouble of finding good parameters, and I will now share my results.
74 |
75 | I have found that a linear SVM works fairly well for programming language classification. In particular, I've gotten a linear SVM to reach a 93% success rate on new samples, and most of those mistakes were reasonable (e.g. mistaking C for C++, or mistaking Ruby for CoffeeScript). To train a linear SVM, you can set the `SVM_KERNEL` environment variable before running the `trainer` sub-command:
76 |
77 | ```
78 | $ export SVM_KERNEL=linear
79 | ```
80 |
81 | If you want verbose output during training, you can specify another environment variable:
82 |
83 | ```
84 | $ export SVM_VERBOSE=1
85 | ```
86 |
87 | Once you have trained a linear SVM, you can perform a special compression step which will make the classifier faster and smaller. This is a technique which only works for linear SVMs! Run the following command:
88 |
89 | ```
90 | $ go run cmd/svm-shrink/*.go /path/to/classifier.json /path/to/optimized.json
91 | ```
92 |
93 | This will create a classifier file at `/path/to/optimized.json` which is the optimized version of `/path/to/classifier.json`. **Remember, this only works for linear SVMs.**
94 |
95 | For other SVM environment variables, you can check out [this list](https://godoc.org/github.com/unixpickle/whichlang/svm#pkg-constants).
96 |
97 | ### Artificial Neural Networks
98 |
99 | While whichlang does allow you to train ANNs without specifying any hyper-parameters (via grid search), doing so will take a tremendous amount of time. It is highly recommended that you manually specify the parameters for your neural network. I will give one example of training an ANN, but it is up to you to tweak these parameters:
100 |
101 | ```
102 | $ export NEURALNET_VERBOSE=1
103 | $ export NEURALNET_VERBOSE_STEPS=1
104 | $ export NEURALNET_STEP_SIZE=0.01
105 | $ export NEURALNET_MAX_ITERS=100
106 | $ export NEURALNET_HIDDEN_SIZE=150
107 | $ go run cmd/trainer/*.go neuralnet 15 /path/to/samples /path/to/classifier.json
108 | ```
109 |
110 | For more ANN environment variables, you can check out [this list](https://godoc.org/github.com/unixpickle/whichlang/neuralnet#pkg-variables).
111 |
112 | ## Using a classifier
113 |
114 | Using a classifier is as simple as loading in a file. You can check out the [classify command](https://github.com/unixpickle/whichlang/blob/master/cmd/classify/main.go) for a very simple (15-line) example.
115 |
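For a neural-network classifier, a sketch along the same lines (the classifier path and code snippet are placeholders):

```go
package main

import (
	"fmt"
	"io/ioutil"
	"log"

	"github.com/unixpickle/whichlang/neuralnet"
	"github.com/unixpickle/whichlang/tokens"
)

func main() {
	data, err := ioutil.ReadFile("/path/to/classifier.json")
	if err != nil {
		log.Fatal(err)
	}
	net, err := neuralnet.DecodeNetwork(data)
	if err != nil {
		log.Fatal(err)
	}
	freqs := tokens.CountTokens("#include <stdio.h>\nint main() {}").Freqs()
	fmt.Println("language:", net.Classify(freqs))
}
```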
--------------------------------------------------------------------------------