├── tokens
│   ├── freqs.go
│   ├── counts_test.go
│   ├── sample_counts_test.go
│   ├── counts.go
│   └── sample_counts.go
├── cmd
│   ├── server
│   │   ├── assets
│   │   │   ├── index.html
│   │   │   ├── style.css
│   │   │   └── main.js
│   │   └── main.go
│   ├── classify
│   │   └── main.go
│   ├── rater
│   │   ├── types.go
│   │   ├── main.go
│   │   └── rate.go
│   ├── svm-shrink
│   │   └── main.go
│   ├── fetchlang
│   │   ├── repo_search.go
│   │   ├── github.go
│   │   ├── main.go
│   │   └── file_search.go
│   ├── trainer
│   │   └── main.go
│   ├── subsamples
│   │   └── main.go
│   └── fetchlang-pastie
│       └── main.go
├── idtree
│   ├── classifier.go
│   ├── samples.go
│   └── train.go
├── gaussbayes
│   ├── classifier.go
│   └── train.go
├── svm
│   ├── kernel.go
│   ├── classifier.go
│   ├── trainer_params.go
│   └── trainer.go
├── neuralnet
│   ├── env.go
│   ├── gradients.go
│   ├── classifier.go
│   ├── data_set.go
│   └── train.go
├── knn
│   ├── classifier.go
│   └── trainer.go
├── main.go
└── README.md

/tokens/freqs.go:
--------------------------------------------------------------------------------
1 | package tokens
2 | 
3 | // Freqs maps words to their frequencies.
4 | // The frequency for a token X is equal to
5 | // the number of occurrences of X, divided
6 | // by the total number of tokens in the
7 | // document.
8 | type Freqs map[string]float64
9 | 
10 | // Freqs converts word counts into a
11 | // frequency map.
12 | func (c Counts) Freqs() Freqs {
13 |     var totalCount int
14 |     for _, count := range c {
15 |         totalCount += count
16 |     }
17 |     res := Freqs{}
18 |     for word, count := range c {
19 |         res[word] = float64(count) / float64(totalCount)
20 |     }
21 |     return res
22 | }
--------------------------------------------------------------------------------
/cmd/server/assets/index.html:
--------------------------------------------------------------------------------
1 | <!doctype html>
2 | <html>
3 | 
4 | <head>
5 |   <meta charset="utf-8">
6 |   <meta name="viewport" content="width=device-width, initial-scale=1">
7 |   <title>whichlang</title>
8 |   <link rel="stylesheet" href="style.css">
9 |   <script src="main.js"></script>
10 | </head>
11 | 
12 | <body>
13 |   <textarea id="text-input"></textarea>
14 |   <br>
15 |   <button id="classify-button">Classify</button>
16 |   <div id="classification" style="display: none"></div>
17 | </body>
18 | 
19 | </html>
--------------------------------------------------------------------------------
/cmd/server/assets/style.css:
--------------------------------------------------------------------------------
1 | body {
2 |     text-align: center;
3 |     font-family: sans-serif;
4 | }
5 | 
6 | #text-input {
7 |     resize: none;
8 |     margin-top: 10px;
9 |     height: 222px;
10 |     box-sizing: border-box;
11 |     padding: 5px 5px;
12 |     border: 1px solid #375eaa;
13 |     border-radius: 5px;
14 |     font-size: 14px;
15 |     background-color: #ffffd8;
16 |     font-family: monospace;
17 | }
18 | 
19 | #text-input:focus {
20 |     outline: 0;
21 | }
22 | 
23 | @media (max-width: 400px) {
24 |     #text-input {
25 |         width: calc(100% - 20px);
26 |     }
27 | }
28 | 
29 | @media (min-width: 400px) {
30 |     #text-input {
31 |         width: 360px;
32 |     }
33 | }
34 | 
35 | #classify-button {
36 |     margin: 5px 0 0 0;
37 |     width: 110px;
38 |     height: 35px;
39 |     border: 1px solid #375eaa;
40 |     border-radius: 5px;
41 |     background-color: #e0ebf5;
42 |     color: black;
43 |     font-size: 16px;
44 |     cursor: pointer;
45 | }
46 | 
47 | #classification {
48 |     margin-top: 10px;
49 | }
--------------------------------------------------------------------------------
/tokens/counts_test.go:
--------------------------------------------------------------------------------
1 | package tokens
2 | 
3 | import "testing"
4 | 
5 | func TestCountTokens(t *testing.T) {
6 |     document := "Hello this is a Hello123\ttest\nhello hi is1 is123\nhi!"
7 |     actual := CountTokens(document)
8 |     expected := map[string]int{
9 |         "Hello123": 1,
10 |         "is1":      1,
11 |         "is123":    1,
12 |         "Hello":    2,
13 |         "this":     1,
14 |         "is":       3,
15 |         "a":        1,
16 |         "123":      2,
17 |         "test":     1,
18 |         "hello":    1,
19 |         "hi":       2,
20 |         "1":        1,
21 |         "!":        1,
22 |         "hi!":      1,
23 | 
24 |         "\nHello": 1,
25 |         "test\n":  1,
26 |         "\nhello": 1,
27 |         "is123\n": 1,
28 |         "123\n":   1,
29 |         "\nhi":    1,
30 |         "\nhi!":   1,
31 |         "hi!\n":   1,
32 |         "!\n":     1,
33 |     }
34 | 
35 |     for x, count := range expected {
36 |         if actual[x] != count {
37 |             t.Error("expected count", count, "for", x, "but got", actual[x])
38 |         }
39 |     }
40 | 
41 |     for x := range actual {
42 |         if expected[x] == 0 {
43 |             t.Error("got unexpected token:", x)
44 |         }
45 |     }
46 | }
--------------------------------------------------------------------------------
/cmd/classify/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 |     "fmt"
5 |     "io/ioutil"
6 |     "os"
7 | 
8 |     "github.com/unixpickle/whichlang"
9 |     "github.com/unixpickle/whichlang/tokens"
10 | )
11 | 
12 | func main() {
13 |     if len(os.Args) != 4 {
14 |         fmt.Fprintln(os.Stderr, "Usage: classify <algorithm> <classifier-file> <file>")
15 |         os.Exit(1)
16 |     }
17 | 
18 |     decoder := whichlang.Decoders[os.Args[1]]
19 |     if decoder == nil {
20 |         fmt.Fprintln(os.Stderr, "Unknown algorithm:", os.Args[1])
21 |         os.Exit(1)
22 |     }
23 | 
24 |     classifierData, err := ioutil.ReadFile(os.Args[2])
25 |     if err != nil {
26 |         fmt.Fprintln(os.Stderr, err)
27 |         os.Exit(1)
28 |     }
29 | 
30 |     classifier, err := decoder(classifierData)
31 |     if err != nil {
32 |         fmt.Fprintln(os.Stderr, "Failed to decode classifier:", err)
33 |         os.Exit(1)
34 |     }
35 | 
36 |     contents, err := ioutil.ReadFile(os.Args[3])
37 |     if err != nil {
38 |         fmt.Fprintln(os.Stderr, err)
39 |         os.Exit(1)
40 |     }
41 | 
42 |     counts := tokens.CountTokens(string(contents))
43 |     freqs := counts.Freqs()
44 |     language := classifier.Classify(freqs)
45 |     fmt.Println("Classification:", language)
46 | }
--------------------------------------------------------------------------------
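The two tokens files above and the classify command compose into a small pipeline: count tokens, normalize counts to frequencies, and classify. A minimal sketch of the same flow used as a library; the algorithm name and classifier path are assumptions for illustration, not part of the repository:

    package main

    import (
        "fmt"
        "io/ioutil"
        "log"

        "github.com/unixpickle/whichlang"
        "github.com/unixpickle/whichlang/tokens"
    )

    func main() {
        // Assumes a classifier was previously saved by cmd/trainer
        // under the hypothetical name "classifier.json".
        data, err := ioutil.ReadFile("classifier.json")
        if err != nil {
            log.Fatal(err)
        }
        classifier, err := whichlang.Decoders["idtree"](data)
        if err != nil {
            log.Fatal(err)
        }
        freqs := tokens.CountTokens("int main() { return 0; }").Freqs()
        fmt.Println(classifier.Classify(freqs))
    }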
/cmd/rater/types.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "sort" 4 | 5 | type Rating struct { 6 | Correct int 7 | Total int 8 | } 9 | 10 | func (r *Rating) Frac() float64 { 11 | return float64(r.Correct) / float64(r.Total) 12 | } 13 | 14 | type OverallRating struct { 15 | Rating 16 | LangRatings []*LangRating 17 | } 18 | 19 | func NewOverallRating(correct, total int, l []*LangRating) *OverallRating { 20 | sorter := ratingSorter(make([]*LangRating, len(l))) 21 | copy(sorter, l) 22 | sort.Sort(sorter) 23 | return &OverallRating{ 24 | Rating: Rating{ 25 | Correct: correct, 26 | Total: total, 27 | }, 28 | LangRatings: []*LangRating(sorter), 29 | } 30 | } 31 | 32 | func (o *OverallRating) LongestLangName() string { 33 | var longest string 34 | for _, lang := range o.LangRatings { 35 | if len(lang.Language) > len(longest) { 36 | longest = lang.Language 37 | } 38 | } 39 | return longest 40 | } 41 | 42 | type LangRating struct { 43 | Rating 44 | Language string 45 | } 46 | 47 | type ratingSorter []*LangRating 48 | 49 | func (r ratingSorter) Len() int { 50 | return len(r) 51 | } 52 | 53 | func (r ratingSorter) Less(i, j int) bool { 54 | return r[i].Frac() > r[j].Frac() 55 | } 56 | 57 | func (r ratingSorter) Swap(i, j int) { 58 | r[i], r[j] = r[j], r[i] 59 | } 60 | -------------------------------------------------------------------------------- /cmd/server/assets/main.js: -------------------------------------------------------------------------------- 1 | (function() { 2 | 3 | var textInput; 4 | var request = null; 5 | var classificationLabel; 6 | 7 | function classify() { 8 | classificationLabel.style.display = 'block'; 9 | classificationLabel.innerText = 'Loading...'; 10 | if (request !== null) { 11 | request.abort(); 12 | } 13 | request = new XMLHttpRequest(); 14 | request.onreadystatechange = function() { 15 | if (request.readyState === 4) { 16 | handleClassification(request.responseText); 17 | request = null; 18 | } 19 | }; 20 | var time = new Date().getTime(); 21 | request.open('POST', '/classify?time=' + time, true); 22 | request.send(textInput.value); 23 | } 24 | 25 | function handleClassification(classification) { 26 | var obj = JSON.parse(classification); 27 | 28 | classificationLabel.innerText = 'Classification: ' + obj.lang; 29 | classificationLabel.style.display = 'block'; 30 | } 31 | 32 | window.addEventListener('load', function() { 33 | textInput = document.getElementById('text-input'); 34 | classificationLabel = document.getElementById('classification'); 35 | var classifyButton = document.getElementById('classify-button'); 36 | classifyButton.addEventListener('click', classify); 37 | }); 38 | 39 | })(); 40 | -------------------------------------------------------------------------------- /idtree/classifier.go: -------------------------------------------------------------------------------- 1 | package idtree 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/unixpickle/whichlang/tokens" 7 | ) 8 | 9 | type Classifier struct { 10 | LeafClassification *string 11 | 12 | Keyword string 13 | Threshold float64 14 | 15 | FalseBranch *Classifier 16 | TrueBranch *Classifier 17 | } 18 | 19 | func DecodeClassifier(d []byte) (*Classifier, error) { 20 | var res Classifier 21 | if err := json.Unmarshal(d, &res); err != nil { 22 | return nil, err 23 | } 24 | return &res, nil 25 | } 26 | 27 | func (c *Classifier) Classify(f tokens.Freqs) string { 28 | if c.LeafClassification == nil { 29 | if f[c.Keyword] > c.Threshold { 
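// The sample uses this keyword more often than the node's threshold,
// so classification descends into the true branch; otherwise the
// false branch decides.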
30 |             return c.TrueBranch.Classify(f)
31 |         } else {
32 |             return c.FalseBranch.Classify(f)
33 |         }
34 |     } else {
35 |         return *c.LeafClassification
36 |     }
37 | }
38 | 
39 | func (c *Classifier) Encode() []byte {
40 |     res, _ := json.Marshal(c)
41 |     return res
42 | }
43 | 
44 | func (c *Classifier) Languages() []string {
45 |     if c.LeafClassification != nil {
46 |         return []string{*c.LeafClassification}
47 |     }
48 | 
49 |     seen := map[string]bool{}
50 |     for _, lang := range c.FalseBranch.Languages() {
51 |         seen[lang] = true
52 |     }
53 |     for _, lang := range c.TrueBranch.Languages() {
54 |         seen[lang] = true
55 |     }
56 | 
57 |     res := make([]string, 0, len(seen))
58 |     for lang := range seen {
59 |         res = append(res, lang)
60 |     }
61 |     return res
62 | }
--------------------------------------------------------------------------------
/tokens/sample_counts_test.go:
--------------------------------------------------------------------------------
1 | package tokens
2 | 
3 | import (
4 |     "math"
5 |     "testing"
6 | )
7 | 
8 | func TestSampleCountsPruneFreqs(t *testing.T) {
9 |     docs := SampleCounts{
10 |         "A": []Counts{
11 |             {"Foo": 1, "Bar": 3, "Baz": 2, "Once1": 1},
12 |             {"Foo": 1, "Bar": 1, "Once2": 15},
13 |         },
14 |         "B": []Counts{
15 |             {"Baz": 15, "Once3": 17},
16 |         },
17 |     }
18 |     docs.Prune(1)
19 |     actual := docs.SampleFreqs()
20 |     expected := map[string][]Freqs{
21 |         "A": []Freqs{
22 |             {"Foo": 1.0 / 7.0, "Bar": 3.0 / 7.0, "Baz": 2.0 / 7.0},
23 |             {"Foo": 1.0 / 17.0, "Bar": 1.0 / 17.0},
24 |         },
25 |         "B": []Freqs{
26 |             {"Baz": 15.0 / 32.0},
27 |         },
28 |     }
29 |     for lang, freqs := range expected {
30 |         actualFreqs := actual[lang]
31 |         if len(actualFreqs) != len(freqs) {
32 |             t.Error("unexpected document count for", lang)
33 |             continue
34 |         }
35 |         for i, actualFreq := range actualFreqs {
36 |             expFreq := freqs[i]
37 |             if !freqsApproxEqual(actualFreq, expFreq) {
38 |                 t.Error("incorrect freq", actualFreq)
39 |             }
40 |         }
41 |     }
42 |     for lang := range actual {
43 |         if expected[lang] == nil {
44 |             t.Error("unexpected language", lang)
45 |         }
46 |     }
47 | }
48 | 
49 | func freqsApproxEqual(f1, f2 Freqs) bool {
50 |     if len(f1) != len(f2) {
51 |         return false
52 |     }
53 |     for key, val := range f1 {
54 |         if math.Abs(val-f2[key]) > 1e-5 {
55 |             return false
56 |         }
57 |     }
58 |     return true
59 | }
--------------------------------------------------------------------------------
/cmd/rater/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 |     "fmt"
5 |     "io/ioutil"
6 |     "os"
7 | 
8 |     "github.com/unixpickle/whichlang"
9 |     "github.com/unixpickle/whichlang/tokens"
10 | )
11 | 
12 | func main() {
13 |     if len(os.Args) != 4 {
14 |         fmt.Fprintln(os.Stderr, "Usage: rater <algorithm> <classifier-file> <sample-dir>")
15 |         os.Exit(1)
16 |     }
17 | 
18 |     decoder := whichlang.Decoders[os.Args[1]]
19 |     if decoder == nil {
20 |         fmt.Fprintln(os.Stderr, "Unknown algorithm:", os.Args[1])
21 |         os.Exit(1)
22 |     }
23 | 
24 |     classifierData, err := ioutil.ReadFile(os.Args[2])
25 |     if err != nil {
26 |         fmt.Fprintln(os.Stderr, err)
27 |         os.Exit(1)
28 |     }
29 | 
30 |     classifier, err := decoder(classifierData)
31 |     if err != nil {
32 |         fmt.Fprintln(os.Stderr, "Failed to decode classifier:", err)
33 |         os.Exit(1)
34 |     }
35 | 
36 |     samples, err := tokens.ReadSampleCounts(os.Args[3])
37 |     if err != nil {
38 |         fmt.Fprintln(os.Stderr, "Failed to read samples:", err)
39 |         os.Exit(1)
40 |     }
41 | 
42 |     rating := Rate(classifier, samples)
43 | 
44 |     fmt.Printf("Success rate: %d/%d or %0.2f%%\n", rating.Correct, rating.Total,
45 |         100*rating.Frac())
46 | 
47 |     nameLength := len(rating.LongestLangName())
48 |     for _, rating := range rating.LangRatings {
49 |         paddedName := rating.Language
50 |         for len(paddedName) < nameLength {
51 |             paddedName = " " + paddedName
52 |         }
53 |         fmt.Printf("%s success rate %d/%d or %0.2f%%\n", paddedName,
54 |             rating.Correct, rating.Total, 100*rating.Frac())
55 |     }
56 | }
--------------------------------------------------------------------------------
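A small sketch of how the rater's types behave; the counts are hypothetical:

    langs := []*LangRating{
        {Rating: Rating{Correct: 5, Total: 10}, Language: "Go"},
        {Rating: Rating{Correct: 9, Total: 10}, Language: "C"},
    }
    overall := NewOverallRating(14, 20, langs)
    // NewOverallRating sorts descending by success fraction, so the
    // 90% language comes first; the embedded Rating gives the overall rate.
    fmt.Println(overall.LangRatings[0].Language) // C
    fmt.Println(overall.Frac())                  // 0.7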
/idtree/samples.go:
--------------------------------------------------------------------------------
1 | package idtree
2 | 
3 | import "github.com/unixpickle/whichlang/tokens"
4 | 
5 | type linearSample struct {
6 |     freqs []float64
7 |     lang  string
8 | }
9 | 
10 | func freqsToLinearSamples(toks []string, freqs map[string][]tokens.Freqs) []linearSample {
11 |     var res []linearSample
12 |     for lang, freqsList := range freqs {
13 |         for _, freqs := range freqsList {
14 |             s := linearSample{
15 |                 lang:  lang,
16 |                 freqs: make([]float64, len(toks)),
17 |             }
18 |             for i, tok := range toks {
19 |                 s.freqs[i] = freqs[tok]
20 |             }
21 |             res = append(res, s)
22 |         }
23 |     }
24 |     return res
25 | }
26 | 
27 | func languageMajority(samples []linearSample) string {
28 |     counts := map[string]int{}
29 |     for _, sample := range samples {
30 |         counts[sample.lang]++
31 |     }
32 | 
33 |     var maxCount int
34 |     var maxLang string
35 |     for lang, count := range counts {
36 |         if count > maxCount {
37 |             maxCount = count
38 |             maxLang = lang
39 |         }
40 |     }
41 | 
42 |     return maxLang
43 | }
44 | 
45 | // A sampleSorter implements sort.Interface
46 | // and facilitates sorting linear samples
47 | // by the frequency of a given token.
48 | type sampleSorter struct {
49 |     samples  []linearSample
50 |     tokenIdx int
51 | }
52 | 
53 | func (s *sampleSorter) Len() int {
54 |     return len(s.samples)
55 | }
56 | 
57 | func (s *sampleSorter) Swap(i, j int) {
58 |     s.samples[i], s.samples[j] = s.samples[j], s.samples[i]
59 | }
60 | 
61 | func (s *sampleSorter) Less(i, j int) bool {
62 |     f1 := s.samples[i].freqs[s.tokenIdx]
63 |     f2 := s.samples[j].freqs[s.tokenIdx]
64 |     return f1 < f2
65 | }
--------------------------------------------------------------------------------
/cmd/server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 |     "encoding/json"
5 |     "fmt"
6 |     "io/ioutil"
7 |     "net/http"
8 |     "os"
9 | 
10 |     "github.com/unixpickle/whichlang"
11 |     "github.com/unixpickle/whichlang/tokens"
12 | )
13 | 
14 | func main() {
15 |     if len(os.Args) != 5 {
16 |         fmt.Fprintln(os.Stderr, "Usage: server <algorithm> <classifier-file> <asset-dir> <port>")
17 |         os.Exit(1)
18 |     }
19 | 
20 |     classifier := readClassifier()
21 |     assets := os.Args[3]
22 | 
23 |     http.HandleFunc("/classify", func(w http.ResponseWriter, r *http.Request) {
24 |         contents, err := ioutil.ReadAll(r.Body)
25 |         if err != nil {
26 |             http.Error(w, err.Error(), http.StatusInternalServerError)
27 |             return
28 |         }
29 |         counts := tokens.CountTokens(string(contents))
30 |         freqs := counts.Freqs()
31 |         lang := classifier.Classify(freqs)
32 |         jsonObj := map[string]interface{}{"lang": lang}
33 |         jsonData, _ := json.Marshal(jsonObj)
34 |         w.Header().Set("Content-Type", "application/json")
35 |         w.Write(jsonData)
36 |     })
37 |     http.Handle("/", http.FileServer(http.Dir(assets)))
38 | 
39 |     if err := http.ListenAndServe(":"+os.Args[4], nil); err != nil {
40 |         fmt.Fprintln(os.Stderr, err)
41 |         os.Exit(1)
42 |     }
43 | }
44 | 
45 | func readClassifier() whichlang.Classifier {
46 |     decoder := whichlang.Decoders[os.Args[1]]
47 |     if decoder == nil {
48 |         fmt.Fprintln(os.Stderr, "Unknown algorithm:", os.Args[1])
49 |         os.Exit(1)
50 |     }
51 | 
52 |     data, err := ioutil.ReadFile(os.Args[2])
53 |     if err != nil {
54 |         fmt.Fprintln(os.Stderr, err)
55 |         os.Exit(1)
56 |     }
57 | 
58 |     c, err := decoder(data)
59 |     if err != nil {
60 |         fmt.Fprintln(os.Stderr, err)
61 |         os.Exit(1)
62 |     }
63 | 
64 |     return c
65 | }
--------------------------------------------------------------------------------
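The /classify endpoint takes raw source text as the POST body and answers with a JSON object such as {"lang": "C"}. A client sketch, assuming the server was started on port 8080 (the port is whatever was passed as the fourth argument):

    package main

    import (
        "encoding/json"
        "fmt"
        "log"
        "net/http"
        "strings"
    )

    func main() {
        resp, err := http.Post("http://localhost:8080/classify", "text/plain",
            strings.NewReader("#include <stdio.h>\nint main() { return 0; }"))
        if err != nil {
            log.Fatal(err)
        }
        defer resp.Body.Close()

        // The server responds with {"lang": "..."}.
        var result struct {
            Lang string `json:"lang"`
        }
        if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
            log.Fatal(err)
        }
        fmt.Println("Classification:", result.Lang)
    }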
"Unknown algorithm:", os.Args[1]) 49 | os.Exit(1) 50 | } 51 | 52 | data, err := ioutil.ReadFile(os.Args[2]) 53 | if err != nil { 54 | fmt.Fprintln(os.Stderr, err) 55 | os.Exit(1) 56 | } 57 | 58 | c, err := decoder(data) 59 | if err != nil { 60 | fmt.Fprintln(os.Stderr, err) 61 | os.Exit(1) 62 | } 63 | 64 | return c 65 | } 66 | -------------------------------------------------------------------------------- /gaussbayes/classifier.go: -------------------------------------------------------------------------------- 1 | // Package gaussbayes implements naive Bayesian 2 | // classification using the assumption that token 3 | // frequencies follow Gaussian distributions. 4 | package gaussbayes 5 | 6 | import ( 7 | "encoding/json" 8 | "math" 9 | 10 | "github.com/unixpickle/whichlang/tokens" 11 | ) 12 | 13 | // Gaussian is a Gaussian probability distribution. 14 | type Gaussian struct { 15 | Mean float64 16 | Variance float64 17 | } 18 | 19 | // EvalLog evaluates the natural logarithm of the 20 | // density function at a given x value. 21 | func (g Gaussian) EvalLog(x float64) float64 { 22 | coeff := 1 / math.Sqrt(2*g.Variance*math.Pi) 23 | exp := -math.Pow(x-g.Mean, 2) / (2 * g.Variance) 24 | return math.Log(coeff) + exp 25 | } 26 | 27 | type Classifier struct { 28 | LangGaussians map[string]map[string]Gaussian 29 | } 30 | 31 | func DecodeClassifier(d []byte) (*Classifier, error) { 32 | var c Classifier 33 | if err := json.Unmarshal(d, &c); err != nil { 34 | return nil, err 35 | } 36 | return &c, nil 37 | } 38 | 39 | func (c *Classifier) Classify(f tokens.Freqs) string { 40 | var bestLanguage string 41 | var bestLogProbability float64 42 | for lang, dists := range c.LangGaussians { 43 | var probLog float64 44 | for token, gaussian := range dists { 45 | probLog += gaussian.EvalLog(f[token]) 46 | } 47 | if bestLanguage == "" || probLog > bestLogProbability { 48 | bestLanguage = lang 49 | bestLogProbability = probLog 50 | } 51 | } 52 | return bestLanguage 53 | } 54 | 55 | func (c *Classifier) Encode() []byte { 56 | res, _ := json.Marshal(c) 57 | return res 58 | } 59 | 60 | func (c *Classifier) Languages() []string { 61 | var languages []string 62 | for lang := range c.LangGaussians { 63 | languages = append(languages, lang) 64 | } 65 | return languages 66 | } 67 | -------------------------------------------------------------------------------- /cmd/svm-shrink/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | 9 | "github.com/unixpickle/num-analysis/linalg" 10 | "github.com/unixpickle/whichlang/svm" 11 | ) 12 | 13 | func main() { 14 | if len(os.Args) != 3 { 15 | fmt.Println("Usage: svm-shrink ") 16 | os.Exit(1) 17 | } 18 | 19 | data, err := ioutil.ReadFile(os.Args[1]) 20 | if err != nil { 21 | die(err) 22 | } 23 | 24 | classifier, err := svm.DecodeClassifier(data) 25 | if err != nil { 26 | die(err) 27 | } else if classifier.Kernel.Type != svm.LinearKernel { 28 | die(errors.New("can only shrink linear classifiers")) 29 | } 30 | 31 | langs := classifier.Languages() 32 | 33 | newClassifier := &svm.Classifier{ 34 | Keywords: classifier.Keywords, 35 | Kernel: classifier.Kernel, 36 | SampleVectors: make([]linalg.Vector, len(langs)), 37 | Classifiers: map[string]svm.BinaryClassifier{}, 38 | } 39 | 40 | for i, lang := range langs { 41 | newClassifier.SampleVectors[i] = combineLanguageVecs(classifier, lang) 42 | bc := svm.BinaryClassifier{ 43 | SupportVectors: []int{i}, 44 | Weights: 
[]float64{1}, 45 | Threshold: classifier.Classifiers[lang].Threshold, 46 | } 47 | newClassifier.Classifiers[lang] = bc 48 | } 49 | 50 | encoded := newClassifier.Encode() 51 | if err := ioutil.WriteFile(os.Args[2], encoded, 0755); err != nil { 52 | die(err) 53 | } 54 | } 55 | 56 | func combineLanguageVecs(c *svm.Classifier, lang string) linalg.Vector { 57 | sum := make(linalg.Vector, len(c.Keywords)) 58 | bc := c.Classifiers[lang] 59 | for i, idx := range bc.SupportVectors { 60 | sum.Add(c.SampleVectors[idx].Copy().Scale(bc.Weights[i])) 61 | } 62 | return sum 63 | } 64 | 65 | func die(e error) { 66 | fmt.Fprintln(os.Stderr, e) 67 | os.Exit(1) 68 | } 69 | -------------------------------------------------------------------------------- /cmd/fetchlang/repo_search.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "net/url" 6 | ) 7 | 8 | // Search asynchronously lists repositories which 9 | // are written in the given programming language. 10 | // Repository names are of the form "user/repo". 11 | // 12 | // The caller should close the done argument when 13 | // they do not need any more results. 14 | // When no results are left, or when done is closed, 15 | // or on error, both returned channels are closed. 16 | // 17 | // If search results cannot be obtained, an error 18 | // is sent on the error channel. 19 | func (g *GithubClient) Search(lang string, done <-chan struct{}) (<-chan string, <-chan error) { 20 | nameChan := make(chan string, 0) 21 | errChan := make(chan error, 1) 22 | go func() { 23 | defer func() { 24 | close(nameChan) 25 | close(errChan) 26 | }() 27 | u := repositorySearchURL(lang) 28 | for u != nil { 29 | select { 30 | case <-done: 31 | return 32 | default: 33 | } 34 | 35 | body, next, err := g.request(u.String()) 36 | if err != nil { 37 | errChan <- err 38 | return 39 | } 40 | u = next 41 | 42 | var obj struct { 43 | Items []struct { 44 | FullName string `json:"full_name"` 45 | } `json:"items"` 46 | } 47 | if err := json.Unmarshal(body, &obj); err != nil { 48 | errChan <- err 49 | return 50 | } 51 | 52 | for _, x := range obj.Items { 53 | select { 54 | case nameChan <- x.FullName: 55 | case <-done: 56 | return 57 | } 58 | } 59 | } 60 | }() 61 | return nameChan, errChan 62 | } 63 | 64 | func repositorySearchURL(lang string) *url.URL { 65 | return &url.URL{ 66 | Scheme: "https", 67 | Host: "api.github.com", 68 | Path: "search/repositories", 69 | RawQuery: url.Values{ 70 | "order": []string{"desc"}, 71 | "q": []string{"language:" + lang}, 72 | }.Encode(), 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /cmd/rater/rate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "runtime" 5 | "sync" 6 | 7 | "github.com/unixpickle/whichlang" 8 | "github.com/unixpickle/whichlang/tokens" 9 | ) 10 | 11 | type Challenge struct { 12 | Language string 13 | Sample tokens.Freqs 14 | } 15 | 16 | type Result struct { 17 | Language string 18 | Correct bool 19 | } 20 | 21 | func Rate(c whichlang.Classifier, s map[string][]tokens.Counts) *OverallRating { 22 | var wg sync.WaitGroup 23 | challengeChan := make(chan Challenge, 0) 24 | resultChan := make(chan Result, 0) 25 | 26 | for i := 0; i < runtime.GOMAXPROCS(0); i++ { 27 | wg.Add(1) 28 | go func() { 29 | defer wg.Done() 30 | for challenge := range challengeChan { 31 | correct := (c.Classify(challenge.Sample) == challenge.Language) 32 | 
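// Each worker reports one Result per challenge; the loop
// below the goroutines tallies totals per language.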
resultChan <- Result{ 33 | Language: challenge.Language, 34 | Correct: correct, 35 | } 36 | } 37 | }() 38 | } 39 | 40 | go func() { 41 | for lang, langSamples := range s { 42 | for _, sample := range langSamples { 43 | challengeChan <- Challenge{lang, sample.Freqs()} 44 | } 45 | } 46 | close(challengeChan) 47 | }() 48 | 49 | go func() { 50 | wg.Wait() 51 | close(resultChan) 52 | }() 53 | 54 | var total, successes int 55 | langSuccesses := map[string]int{} 56 | langTotals := map[string]int{} 57 | 58 | for result := range resultChan { 59 | total++ 60 | langTotals[result.Language]++ 61 | if result.Correct { 62 | successes++ 63 | langSuccesses[result.Language]++ 64 | } 65 | } 66 | 67 | langRatings := makeLangRatings(langSuccesses, langTotals) 68 | return NewOverallRating(successes, total, langRatings) 69 | } 70 | 71 | func makeLangRatings(succ, total map[string]int) []*LangRating { 72 | res := make([]*LangRating, 0, len(total)) 73 | for lang, totalCount := range total { 74 | lr := &LangRating{ 75 | Rating: Rating{ 76 | Correct: succ[lang], 77 | Total: totalCount, 78 | }, 79 | Language: lang, 80 | } 81 | res = append(res, lr) 82 | } 83 | return res 84 | } 85 | -------------------------------------------------------------------------------- /svm/kernel.go: -------------------------------------------------------------------------------- 1 | package svm 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "strconv" 7 | 8 | "github.com/unixpickle/num-analysis/linalg" 9 | ) 10 | 11 | type KernelType int 12 | 13 | const ( 14 | // LinearKernel generates a linear classifier 15 | // with no parameters. 16 | LinearKernel KernelType = iota 17 | 18 | // PolynomialKernel computes inner products as 19 | // (x*y + k1)^k2, where k1 and k2 are parameters. 20 | PolynomialKernel 21 | 22 | // RadialBasisKernel computes inner products as 23 | // exp(-k1*||x-y||^2), where k1 is a parameter. 24 | RadialBasisKernel 25 | ) 26 | 27 | // A Kernel computes inner products of vectors 28 | // after transforming them into some space. 29 | type Kernel struct { 30 | Type KernelType 31 | Params []float64 32 | } 33 | 34 | // Product returns the product of two vectors 35 | // under this kernel. 36 | func (k *Kernel) Product(v1, v2 linalg.Vector) float64 { 37 | switch k.Type { 38 | case LinearKernel: 39 | return v1.Dot(v2) 40 | case PolynomialKernel: 41 | if len(k.Params) != 2 { 42 | panic("expected two parameters for polynomial kernel") 43 | } 44 | return math.Pow(v1.Dot(v2)+k.Params[0], k.Params[1]) 45 | case RadialBasisKernel: 46 | if len(k.Params) != 1 { 47 | panic("expected one parameter for radial basis kernel") 48 | } 49 | diff := v1.Copy().Scale(-1).Add(v2) 50 | return math.Exp(-k.Params[0] * diff.Dot(diff)) 51 | default: 52 | panic("unknown kernel type: " + strconv.Itoa(int(k.Type))) 53 | } 54 | } 55 | 56 | // String returns a mathematical formula which 57 | // represents this kernel (e.g. "(x*y+1)^2"). 
58 | func (k *Kernel) String() string {
59 |     switch k.Type {
60 |     case LinearKernel:
61 |         return "x*y"
62 |     case PolynomialKernel:
63 |         if len(k.Params) != 2 {
64 |             panic("expected two parameters for polynomial kernel")
65 |         }
66 |         return fmt.Sprintf("(x*y + %f)^%f", k.Params[0], k.Params[1])
67 |     case RadialBasisKernel:
68 |         if len(k.Params) != 1 {
69 |             panic("expected one parameter for radial basis kernel")
70 |         }
71 |         return fmt.Sprintf("exp(-%f*||x-y||^2)", k.Params[0])
72 |     default:
73 |         panic("unknown kernel type: " + strconv.Itoa(int(k.Type)))
74 |     }
75 | }
--------------------------------------------------------------------------------
/cmd/trainer/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 |     "fmt"
5 |     "io/ioutil"
6 |     "math/rand"
7 |     "os"
8 |     "strconv"
9 |     "time"
10 | 
11 |     "github.com/unixpickle/whichlang"
12 |     "github.com/unixpickle/whichlang/tokens"
13 | )
14 | 
15 | const HelpColumnSize = 10
16 | 
17 | func main() {
18 |     // Several machine learning algorithms depend on
19 |     // random starting positions.
20 |     rand.Seed(time.Now().UnixNano())
21 | 
22 |     if len(os.Args) != 5 {
23 |         dieUsage()
24 |     }
25 | 
26 |     algorithm := os.Args[1]
27 | 
28 |     trainer := whichlang.Trainers[algorithm]
29 |     if trainer == nil {
30 |         fmt.Fprintln(os.Stderr, "Unknown algorithm:", algorithm)
31 |         dieUsage()
32 |     }
33 | 
34 |     ubiquity, err := strconv.Atoi(os.Args[2])
35 |     if err != nil {
36 |         fmt.Fprintln(os.Stderr, "Invalid ubiquity:", os.Args[2], "(expected integer)")
37 |         os.Exit(1)
38 |     }
39 | 
40 |     sampleDir := os.Args[3]
41 |     outputFile := os.Args[4]
42 | 
43 |     counts, err := tokens.ReadSampleCounts(sampleDir)
44 |     if err != nil {
45 |         fmt.Fprintln(os.Stderr, err)
46 |         os.Exit(1)
47 |     }
48 | 
49 |     oldCount := counts.NumTokens()
50 |     fmt.Println("Pruning tokens...")
51 |     counts.Prune(ubiquity)
52 |     newCount := counts.NumTokens()
53 |     fmt.Printf("Pruned %d/%d tokens (%d left).\n", (oldCount - newCount),
54 |         oldCount, newCount)
55 | 
56 |     freqs := counts.SampleFreqs()
57 | 
58 |     fmt.Println("Training...")
59 |     classifier := trainer(freqs)
60 | 
61 |     fmt.Println("Saving...")
62 |     data := classifier.Encode()
63 | 
64 |     if err := ioutil.WriteFile(outputFile, data, 0755); err != nil {
65 |         fmt.Fprintln(os.Stderr, "Error writing file:", err)
66 |         os.Exit(1)
67 |     }
68 | }
69 | 
70 | func dieUsage() {
71 |     fmt.Fprintln(os.Stderr, "Usage: trainer <algorithm> <ubiquity> <sample-dir> <output-file>\n\n"+
72 |         " (tokens which appear in <ubiquity> or fewer sample files are pruned.)\n\n"+
73 |         "Available algorithms:")
74 |     for _, name := range whichlang.ClassifierNames {
75 |         spaces := ""
76 |         for i := len(name); i < HelpColumnSize; i++ {
77 |             spaces += " "
78 |         }
79 |         fmt.Fprintln(os.Stderr, " "+name+spaces, whichlang.Descriptions[name])
80 |     }
81 |     fmt.Fprintln(os.Stderr, "")
82 |     os.Exit(1)
83 | }
--------------------------------------------------------------------------------
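Condensed, the trainer is a four-step pipeline: read per-language token counts, prune rare tokens, convert counts to frequencies, and train. A library-level sketch of the same flow; the directory, output path, and ubiquity value are placeholders:

    package main

    import (
        "io/ioutil"
        "log"

        "github.com/unixpickle/whichlang"
        "github.com/unixpickle/whichlang/tokens"
    )

    func main() {
        counts, err := tokens.ReadSampleCounts("samples")
        if err != nil {
            log.Fatal(err)
        }
        // Drop tokens that are too rare to generalize.
        counts.Prune(10)
        classifier := whichlang.Trainers["svm"](counts.SampleFreqs())
        if err := ioutil.WriteFile("classifier.json", classifier.Encode(), 0755); err != nil {
            log.Fatal(err)
        }
    }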
/neuralnet/env.go:
--------------------------------------------------------------------------------
1 | package neuralnet
2 | 
3 | import (
4 |     "math"
5 |     "os"
6 |     "strconv"
7 | )
8 | 
9 | const DefaultMaxIterations = 6400
10 | 
11 | // DefaultHiddenLayerScale specifies how much
12 | // larger the hidden layer is than the output
13 | // layer, by default.
14 | const DefaultHiddenLayerScale = 2.0
15 | 
16 | // VerboseEnvVar is an environment variable
17 | // which can be set to "1" to make the
18 | // neuralnet print out status reports.
19 | var VerboseEnvVar = "NEURALNET_VERBOSE"
20 | 
21 | // VerboseStepsEnvVar is an environment
22 | // variable which can be set to "1" to make
23 | // neuralnet print out status reports after
24 | // each iteration of gradient descent.
25 | var VerboseStepsEnvVar = "NEURALNET_VERBOSE_STEPS"
26 | 
27 | // StepSizeEnvVar is an environment variable
28 | // which can be used to specify the step size
29 | // for use in gradient descent.
30 | var StepSizeEnvVar = "NEURALNET_STEP_SIZE"
31 | 
32 | // MaxItersEnvVar is an environment variable
33 | // specifying the maximum number of iterations
34 | // of gradient descent to perform.
35 | var MaxItersEnvVar = "NEURALNET_MAX_ITERS"
36 | 
37 | // HiddenSizeEnvVar is an environment variable
38 | // specifying the number of hidden neurons.
39 | var HiddenSizeEnvVar = "NEURALNET_HIDDEN_SIZE"
40 | 
41 | func verboseFlag() bool {
42 |     return os.Getenv(VerboseEnvVar) == "1"
43 | }
44 | 
45 | func verboseStepsFlag() bool {
46 |     return os.Getenv(VerboseStepsEnvVar) == "1"
47 | }
48 | 
49 | func stepSizes() []float64 {
50 |     if stepSize := os.Getenv(StepSizeEnvVar); stepSize == "" {
51 |         var res []float64
52 |         for power := -20; power < 10; power++ {
53 |             res = append(res, math.Pow(2, float64(power)))
54 |         }
55 |         return res
56 |     } else {
57 |         val, err := strconv.ParseFloat(stepSize, 64)
58 |         if err != nil {
59 |             panic(err)
60 |         }
61 |         return []float64{val}
62 |     }
63 | }
64 | 
65 | func maxIterations() int {
66 |     if max := os.Getenv(MaxItersEnvVar); max == "" {
67 |         return DefaultMaxIterations
68 |     } else {
69 |         val, err := strconv.Atoi(max)
70 |         if err != nil {
71 |             panic(err)
72 |         }
73 |         return val
74 |     }
75 | }
76 | 
77 | func hiddenSize(outputCount int) int {
78 |     if size := os.Getenv(HiddenSizeEnvVar); size == "" {
79 |         return int(float64(outputCount)*DefaultHiddenLayerScale + 0.5)
80 |     } else {
81 |         val, err := strconv.Atoi(size)
82 |         if err != nil {
83 |             panic(err)
84 |         }
85 |         return val
86 |     }
87 | }
--------------------------------------------------------------------------------
/cmd/subsamples/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 |     "fmt"
5 |     "io/ioutil"
6 |     "os"
7 |     "path/filepath"
8 |     "strconv"
9 |     "strings"
10 | )
11 | 
12 | func main() {
13 |     if len(os.Args) != 3 {
14 |         fmt.Fprintf(os.Stderr, "Usage: %s <lines> <sample-dir>\n", os.Args[0])
15 |         os.Exit(1)
16 |     }
17 |     numLines, err := strconv.Atoi(os.Args[1])
18 |     if err != nil {
19 |         fmt.Fprintln(os.Stderr, err)
20 |         os.Exit(1)
21 |     }
22 |     sampleDir := os.Args[2]
23 |     if err := subsample(numLines, sampleDir); err != nil {
24 |         fmt.Fprintln(os.Stderr, err)
25 |         os.Exit(1)
26 |     }
27 | }
28 | 
29 | func subsample(numLines int, dirPath string) error {
30 |     dir, err := os.Open(dirPath)
31 |     if err != nil {
32 |         return err
33 |     }
34 |     listing, err := dir.Readdir(-1)
35 |     dir.Close()
36 |     if err != nil {
37 |         return err
38 |     }
39 |     for _, langDir := range listing {
40 |         if !langDir.IsDir() {
41 |             continue
42 |         }
43 |         p := filepath.Join(dirPath, langDir.Name())
44 |         if err := subsampleLang(numLines, p); err != nil {
45 |             return err
46 |         }
47 |     }
48 |     return nil
49 | }
50 | 
51 | func subsampleLang(numLines int, langDirPath string) error {
52 |     dir, err := os.Open(langDirPath)
53 |     if err != nil {
54 |         return err
55 |     }
56 |     listing, err := dir.Readdir(-1)
57 |     dir.Close()
58 |     if err != nil {
59 |         return err
60 |     }
61 |     for _, fileInfo := range listing {
62 |         if strings.HasPrefix(fileInfo.Name(), ".") {
63 |             continue
64 |         }
65 |         filePath :=
filepath.Join(langDirPath, fileInfo.Name()) 66 | if err := subsampleFile(numLines, filePath); err != nil { 67 | return err 68 | } 69 | } 70 | return nil 71 | } 72 | 73 | func subsampleFile(numLines int, filePath string) error { 74 | contents, err := ioutil.ReadFile(filePath) 75 | if err != nil { 76 | return err 77 | } 78 | lines := strings.Split(string(contents), "\n") 79 | if len(lines) < numLines { 80 | fmt.Println("Skipping file", filePath) 81 | return nil 82 | } 83 | startIndex := (len(lines) - numLines) / 2 84 | splitLines := lines[startIndex : startIndex+numLines] 85 | newPath := newPath(numLines, filePath) 86 | newData := strings.Join(splitLines, "\n") 87 | return ioutil.WriteFile(newPath, []byte(newData), 0755) 88 | } 89 | 90 | func newPath(numLines int, filePath string) string { 91 | sampleTag := "_subsample_" + strconv.Itoa(numLines) 92 | ext := filepath.Ext(filePath) 93 | if ext == "" { 94 | return filePath + sampleTag 95 | } else { 96 | return filePath[:len(filePath)-len(ext)] + sampleTag + ext 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /svm/classifier.go: -------------------------------------------------------------------------------- 1 | package svm 2 | 3 | import ( 4 | "encoding/json" 5 | "math" 6 | 7 | "github.com/unixpickle/num-analysis/kahan" 8 | "github.com/unixpickle/num-analysis/linalg" 9 | "github.com/unixpickle/whichlang/tokens" 10 | ) 11 | 12 | // BinaryClassifier stores info for the 13 | // binary classifiers used in a Classifier. 14 | type BinaryClassifier struct { 15 | // SupportVectors stores indices to 16 | // elements of Classifier.SampleVectors. 17 | SupportVectors []int 18 | 19 | // Weights are the corresponding weights 20 | // for each of the support vectors. 21 | Weights []float64 22 | 23 | Threshold float64 24 | } 25 | 26 | // Classifier uses one-against-all SVMs to 27 | // classify source files. 28 | type Classifier struct { 29 | Keywords []string 30 | Kernel *Kernel 31 | 32 | SampleVectors []linalg.Vector 33 | 34 | // Classifiers maps each language to its 35 | // corresponding one-against-all binary 36 | // classifier. 
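// Classify picks the language whose binary classifier
// produces the largest decision value.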
37 | Classifiers map[string]BinaryClassifier 38 | } 39 | 40 | func DecodeClassifier(d []byte) (*Classifier, error) { 41 | var c Classifier 42 | if err := json.Unmarshal(d, &c); err != nil { 43 | return nil, err 44 | } 45 | return &c, nil 46 | } 47 | 48 | func (c *Classifier) Classify(sample tokens.Freqs) string { 49 | products := c.sampleProducts(sample) 50 | 51 | var bestLanguage string 52 | bestClassification := math.Inf(-1) 53 | 54 | for lang, classifier := range c.Classifiers { 55 | productSum := kahan.NewSummer64() 56 | for i, vecIdx := range classifier.SupportVectors { 57 | productSum.Add(products[vecIdx] * classifier.Weights[i]) 58 | } 59 | productSum.Add(-classifier.Threshold) 60 | if productSum.Sum() > bestClassification { 61 | bestClassification = productSum.Sum() 62 | bestLanguage = lang 63 | } 64 | } 65 | 66 | return bestLanguage 67 | } 68 | 69 | func (c *Classifier) Encode() []byte { 70 | res, _ := json.Marshal(c) 71 | return res 72 | } 73 | 74 | func (c *Classifier) Languages() []string { 75 | res := make([]string, 0, len(c.Classifiers)) 76 | for lang := range c.Classifiers { 77 | res = append(res, lang) 78 | } 79 | return res 80 | } 81 | 82 | func (c *Classifier) sampleProducts(sample tokens.Freqs) []float64 { 83 | vec := c.sampleVector(sample) 84 | res := make([]float64, len(c.SampleVectors)) 85 | for i, s := range c.SampleVectors { 86 | res[i] = c.Kernel.Product(s, vec) 87 | } 88 | return res 89 | } 90 | 91 | func (c *Classifier) sampleVector(sample tokens.Freqs) linalg.Vector { 92 | vec := make(linalg.Vector, len(c.Keywords)) 93 | for i, keyword := range c.Keywords { 94 | vec[i] = sample[keyword] 95 | } 96 | return vec 97 | } 98 | -------------------------------------------------------------------------------- /cmd/fetchlang/github.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "regexp" 12 | "strings" 13 | 14 | "github.com/howeyc/gopass" 15 | ) 16 | 17 | // A GithubClient uses the Github API 18 | // on behalf of a given user. 19 | type GithubClient struct { 20 | User string 21 | Pass string 22 | } 23 | 24 | // PromptGithubClient prompts the user for their 25 | // Github account details, then generates a 26 | // *GithubClient based on these details. 27 | func PromptGithubClient() (*GithubClient, error) { 28 | fmt.Print("Username: ") 29 | username := "" 30 | for { 31 | var ch [1]byte 32 | _, err := os.Stdin.Read(ch[:]) 33 | if err != nil { 34 | return nil, err 35 | } else if ch[0] == '\n' { 36 | break 37 | } else if ch[0] == '\r' { 38 | continue 39 | } 40 | username += string(ch[0]) 41 | } 42 | 43 | fmt.Print("Password: ") 44 | password, err := gopass.GetPasswd() 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | return &GithubClient{ 50 | User: strings.TrimSpace(username), 51 | Pass: string(password), 52 | }, nil 53 | } 54 | 55 | // request accesses an API URL using the 56 | // user's credentials. 57 | // 58 | // It returns an error if the request fails, 59 | // or if Github's API returns an error. 60 | // 61 | // Some requests are naturally paginated, in 62 | // which case the next return argument 63 | // corresponds to the URL of the next page. 
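// When the response is not paginated, or this is the
// last page, next is nil.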
64 | func (g *GithubClient) request(u string) (data []byte, next *url.URL, err error) { 65 | req, err := http.NewRequest("GET", u, nil) 66 | if err != nil { 67 | return nil, nil, err 68 | } 69 | req.Header.Set("Accept", "application/vnd.github.preview.text-match+json") 70 | req.SetBasicAuth(g.User, g.Pass) 71 | res, err := http.DefaultClient.Do(req) 72 | if res != nil { 73 | defer res.Body.Close() 74 | } 75 | if err != nil { 76 | return nil, nil, err 77 | } 78 | body, err := ioutil.ReadAll(res.Body) 79 | if err != nil { 80 | return nil, nil, err 81 | } 82 | 83 | var raw map[string]interface{} 84 | if err := json.Unmarshal(body, &raw); err == nil { 85 | if message, ok := raw["message"]; ok { 86 | if s, ok := message.(string); ok { 87 | return nil, nil, errors.New(s) 88 | } 89 | } 90 | } 91 | 92 | nextPattern := regexp.MustCompile("\\<(.*?)\\>; rel=\"next\"") 93 | match := nextPattern.FindStringSubmatch(res.Header.Get("Link")) 94 | if match != nil { 95 | u, _ := url.Parse(match[1]) 96 | return body, u, nil 97 | } 98 | 99 | return body, nil, nil 100 | } 101 | -------------------------------------------------------------------------------- /knn/classifier.go: -------------------------------------------------------------------------------- 1 | package knn 2 | 3 | import ( 4 | "encoding/json" 5 | "math" 6 | 7 | "github.com/unixpickle/num-analysis/linalg" 8 | "github.com/unixpickle/whichlang/tokens" 9 | ) 10 | 11 | type Sample struct { 12 | Language string 13 | Vector linalg.Vector 14 | } 15 | 16 | type Classifier struct { 17 | Tokens []string 18 | Samples []Sample 19 | 20 | NeighborCount int 21 | } 22 | 23 | func DecodeClassifier(d []byte) (*Classifier, error) { 24 | var res Classifier 25 | if err := json.Unmarshal(d, &res); err != nil { 26 | return nil, err 27 | } 28 | return &res, nil 29 | } 30 | 31 | func (c *Classifier) Classify(f tokens.Freqs) string { 32 | vec := make(linalg.Vector, len(c.Tokens)) 33 | for i, keyw := range c.Tokens { 34 | vec[i] = f[keyw] 35 | } 36 | 37 | vecMag := vec.Dot(vec) 38 | if vecMag == 0 { 39 | return c.Samples[0].Language 40 | } 41 | vec.Scale(1 / math.Sqrt(vecMag)) 42 | 43 | return c.classifyVector(vec) 44 | } 45 | 46 | func (c *Classifier) Encode() []byte { 47 | data, _ := json.Marshal(c) 48 | return data 49 | } 50 | 51 | func (c *Classifier) Languages() []string { 52 | seenLangs := map[string]bool{} 53 | for _, sample := range c.Samples { 54 | seenLangs[sample.Language] = true 55 | } 56 | res := make([]string, 0, len(seenLangs)) 57 | for lang := range seenLangs { 58 | res = append(res, lang) 59 | } 60 | return res 61 | } 62 | 63 | func (c *Classifier) classifyVector(vec linalg.Vector) string { 64 | matches := make([]match, 0, c.NeighborCount) 65 | for _, sample := range c.Samples { 66 | correlation := sample.Vector.Dot(vec) 67 | insertIdx := matchInsertionIndex(matches, correlation) 68 | if insertIdx >= c.NeighborCount { 69 | continue 70 | } 71 | if len(matches) < c.NeighborCount { 72 | matches = append(matches, match{}) 73 | } 74 | copy(matches[insertIdx+1:], matches[insertIdx:]) 75 | matches[insertIdx] = match{ 76 | Language: sample.Language, 77 | Correlation: correlation, 78 | } 79 | } 80 | 81 | return dominantClassification(matches) 82 | } 83 | 84 | func dominantClassification(matches []match) string { 85 | scores := map[string]float64{} 86 | for _, m := range matches { 87 | scores[m.Language] += m.Correlation 88 | } 89 | 90 | var bestLang string 91 | bestScore := math.Inf(-1) 92 | for lang, score := range scores { 93 | if score > bestScore { 94 | 
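// This language's neighbors have the highest summed correlation so far.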
bestScore = score 95 | bestLang = lang 96 | } 97 | } 98 | 99 | return bestLang 100 | } 101 | 102 | type match struct { 103 | Language string 104 | Correlation float64 105 | } 106 | 107 | func matchInsertionIndex(m []match, corr float64) int { 108 | for i, x := range m { 109 | if x.Correlation < corr { 110 | return i 111 | } 112 | } 113 | return len(m) 114 | } 115 | -------------------------------------------------------------------------------- /gaussbayes/train.go: -------------------------------------------------------------------------------- 1 | package gaussbayes 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/unixpickle/whichlang/tokens" 7 | ) 8 | 9 | // Train returns a *Classifier by computing statistical 10 | // properties of the sample data. 11 | func Train(freqs map[string][]tokens.Freqs) *Classifier { 12 | res := &Classifier{LangGaussians: map[string]map[string]Gaussian{}} 13 | tokens := allTokens(freqs) 14 | for lang, samples := range freqs { 15 | gaussians := computeGaussians(samples) 16 | addMissing(gaussians, tokens) 17 | res.LangGaussians[lang] = gaussians 18 | } 19 | regularizeVariances(res) 20 | return res 21 | } 22 | 23 | func computeGaussians(samples []tokens.Freqs) map[string]Gaussian { 24 | res := map[string]Gaussian{} 25 | 26 | computeMeans(samples, res) 27 | computeVariances(samples, res) 28 | 29 | return res 30 | } 31 | 32 | func computeMeans(samples []tokens.Freqs, out map[string]Gaussian) { 33 | for _, sample := range samples { 34 | for keyword, freq := range sample { 35 | outGaussian := out[keyword] 36 | outGaussian.Mean += freq 37 | out[keyword] = outGaussian 38 | } 39 | } 40 | meanScaler := 1 / float64(len(samples)) 41 | for key, g := range out { 42 | g.Mean *= meanScaler 43 | out[key] = g 44 | } 45 | } 46 | 47 | func computeVariances(samples []tokens.Freqs, out map[string]Gaussian) { 48 | for _, sample := range samples { 49 | for keyword, freq := range sample { 50 | outGaussian := out[keyword] 51 | outGaussian.Variance += math.Pow(freq-outGaussian.Mean, 2) 52 | out[keyword] = outGaussian 53 | } 54 | } 55 | varianceScaler := 1 / float64(len(samples)) 56 | for key, g := range out { 57 | g.Variance *= varianceScaler 58 | out[key] = g 59 | } 60 | } 61 | 62 | func addMissing(m map[string]Gaussian, tokens []string) { 63 | for _, token := range tokens { 64 | if _, ok := m[token]; !ok { 65 | m[token] = Gaussian{} 66 | } 67 | } 68 | } 69 | 70 | func allTokens(m map[string][]tokens.Freqs) []string { 71 | res := map[string]bool{} 72 | for _, x := range m { 73 | for _, t := range x { 74 | for w := range t { 75 | res[w] = true 76 | } 77 | } 78 | } 79 | resSlice := make([]string, 0, len(res)) 80 | for w := range res { 81 | resSlice = append(resSlice, w) 82 | } 83 | return resSlice 84 | } 85 | 86 | // regularizeVariances ensures that no variances are zero. 
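// A zero variance would make Gaussian.EvalLog produce infinities,
// so each zero variance is replaced with the smallest nonzero
// variance found anywhere in the classifier.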
87 | func regularizeVariances(c *Classifier) {
88 |     var smallestVariance float64
89 |     for _, m := range c.LangGaussians {
90 |         for _, x := range m {
91 |             if smallestVariance == 0 || (x.Variance < smallestVariance && x.Variance > 0) {
92 |                 smallestVariance = x.Variance
93 |             }
94 |         }
95 |     }
96 |     for _, m := range c.LangGaussians {
97 |         for word, x := range m {
98 |             if x.Variance == 0 {
99 |                 x.Variance = smallestVariance
100 |                 m[word] = x
101 |             }
102 |         }
103 |     }
104 | }
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | // Package whichlang is a suite of machine learning
2 | // tools to classify programming languages.
3 | package whichlang
4 | 
5 | import (
6 |     "github.com/unixpickle/whichlang/gaussbayes"
7 |     "github.com/unixpickle/whichlang/idtree"
8 |     "github.com/unixpickle/whichlang/knn"
9 |     "github.com/unixpickle/whichlang/neuralnet"
10 |     "github.com/unixpickle/whichlang/svm"
11 |     "github.com/unixpickle/whichlang/tokens"
12 | )
13 | 
14 | type Classifier interface {
15 |     // Classify classifies a tokenized source file.
16 |     //
17 |     // The returned string is the name of the
18 |     // programming language in which the file
19 |     // was most likely written.
20 |     Classify(tokens.Freqs) string
21 | 
22 |     // Languages returns all possible languages
23 |     // that Classify() might return.
24 |     // The result is not sorted, and its order
25 |     // may change across calls.
26 |     // Callers should not modify the returned slice.
27 |     Languages() []string
28 | 
29 |     // Encode serializes this classifier as binary
30 |     // data.
31 |     Encode() []byte
32 | }
33 | 
34 | // A Trainer generates a Classifier using
35 | // a collection of tokenized sample files.
36 | type Trainer func(map[string][]tokens.Freqs) Classifier
37 | 
38 | // A Decoder decodes a certain type of
39 | // Classifier from binary data.
40 | type Decoder func(d []byte) (Classifier, error)
41 | 
42 | // ClassifierNames is a slice containing the
43 | // names of every supported classifier.
44 | var ClassifierNames = []string{"idtree", "neuralnet", "knn", "svm", "gaussbayes"}
45 | 
46 | // Trainers maps classifier names to their
47 | // corresponding Trainers.
48 | var Trainers = map[string]Trainer{
49 |     "idtree": func(freqs map[string][]tokens.Freqs) Classifier {
50 |         return idtree.Train(freqs)
51 |     },
52 |     "neuralnet": func(freqs map[string][]tokens.Freqs) Classifier {
53 |         return neuralnet.Train(freqs)
54 |     },
55 |     "knn": func(freqs map[string][]tokens.Freqs) Classifier {
56 |         return knn.Train(freqs)
57 |     },
58 |     "svm": func(freqs map[string][]tokens.Freqs) Classifier {
59 |         return svm.Train(freqs)
60 |     },
61 |     "gaussbayes": func(freqs map[string][]tokens.Freqs) Classifier {
62 |         return gaussbayes.Train(freqs)
63 |     },
64 | }
65 | 
66 | // Decoders maps classifier names to their
67 | // corresponding Decoders.
68 | var Decoders = map[string]Decoder{
69 |     "idtree": func(d []byte) (Classifier, error) {
70 |         return idtree.DecodeClassifier(d)
71 |     },
72 |     "neuralnet": func(d []byte) (Classifier, error) {
73 |         return neuralnet.DecodeNetwork(d)
74 |     },
75 |     "knn": func(d []byte) (Classifier, error) {
76 |         return knn.DecodeClassifier(d)
77 |     },
78 |     "svm": func(d []byte) (Classifier, error) {
79 |         return svm.DecodeClassifier(d)
80 |     },
81 |     "gaussbayes": func(d []byte) (Classifier, error) {
82 |         return gaussbayes.DecodeClassifier(d)
83 |     },
84 | }
85 | 
86 | // Descriptions maps classifier names to
87 | // one-line descriptions of the classifier.
88 | var Descriptions = map[string]string{
89 |     "idtree":     "decision trees generated with ID3",
90 |     "neuralnet":  "feedforward neural network",
91 |     "knn":        "K-nearest neighbors",
92 |     "svm":        "support vector machines",
93 |     "gaussbayes": "naive Bayes with Gaussians",
94 | }
--------------------------------------------------------------------------------
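Because every algorithm is registered in these three maps under the same name, callers can train, persist, and restore classifiers generically. A sketch of the round trip; freqs stands for a map[string][]tokens.Freqs such as SampleCounts.SampleFreqs() returns:

    func roundTrip(freqs map[string][]tokens.Freqs) (whichlang.Classifier, error) {
        trained := whichlang.Trainers["knn"](freqs)
        data := trained.Encode()
        // Decoding with the same algorithm name restores an
        // equivalent classifier from the serialized bytes.
        return whichlang.Decoders["knn"](data)
    }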
/knn/trainer.go:
--------------------------------------------------------------------------------
1 | package knn
2 | 
3 | import (
4 |     "math"
5 |     "math/rand"
6 |     "sort"
7 | 
8 |     "github.com/unixpickle/num-analysis/linalg"
9 |     "github.com/unixpickle/whichlang/tokens"
10 | )
11 | 
12 | // crossValidationFrac specifies the fraction of
13 | // samples which are used for cross-validation
14 | // when determining the optimal k-value.
15 | const crossValidationFrac = 0.3
16 | 
17 | func Train(f map[string][]tokens.Freqs) *Classifier {
18 |     seenToks := map[string]bool{}
19 |     sampleCount := 0
20 |     for _, samples := range f {
21 |         for _, sample := range samples {
22 |             for tok := range sample {
23 |                 seenToks[tok] = true
24 |             }
25 |             sampleCount++
26 |         }
27 |     }
28 | 
29 |     toks := make([]string, 0, len(seenToks))
30 |     for tok := range seenToks {
31 |         toks = append(toks, tok)
32 |     }
33 | 
34 |     samples := make([]Sample, 0, sampleCount)
35 |     for lang, freqSamples := range f {
36 |         for _, freqs := range freqSamples {
37 |             vec := make(linalg.Vector, len(toks))
38 |             for i, token := range toks {
39 |                 vec[i] = freqs[token]
40 |             }
41 |             if mag := vec.Dot(vec); mag != 0 {
42 |                 vec.Scale(1 / math.Sqrt(mag))
43 |             }
44 |             samples = append(samples, Sample{
45 |                 Language: lang,
46 |                 Vector:   vec,
47 |             })
48 |         }
49 |     }
50 | 
51 |     kValue := optimalKValue(samples)
52 |     return &Classifier{
53 |         Tokens:        toks,
54 |         Samples:       samples,
55 |         NeighborCount: kValue,
56 |     }
57 | }
58 | 
59 | func optimalKValue(s []Sample) int {
60 |     crossCount := int(crossValidationFrac * float64(len(s)))
61 |     if crossCount == 0 {
62 |         return 1
63 |     }
64 |     samples := shuffleSamples(s)
65 |     crossSamples := samples[0:crossCount]
66 |     trainingSamples := samples[crossCount:]
67 | 
68 |     crossMatches := sortedCrossMatches(crossSamples, trainingSamples)
69 | 
70 |     bestK := 1
71 |     bestCorrect := 0
72 |     for k := 1; k <= len(trainingSamples); k++ {
73 |         crossCorrect := 0
74 |         for crossIdx, matches := range crossMatches {
75 |             classification := dominantClassification(matches[:k])
76 |             actualLang := crossSamples[crossIdx].Language
77 |             if classification == actualLang {
78 |                 crossCorrect++
79 |             }
80 |         }
81 |         if crossCorrect > bestCorrect {
82 |             bestK = k
83 |             bestCorrect = crossCorrect
84 |         }
85 |     }
86 | 
87 |     return bestK
88 | }
89 | 
90 | func sortedCrossMatches(cross, training []Sample) [][]match {
91 |     res := make([][]match, len(cross))
92 |     for i, crossSample := range cross {
93 |         res[i] = make([]match, len(training))
94 |         for j, trainingSample := range training {
95 |             correlation :=
trainingSample.Vector.Dot(crossSample.Vector) 95 | res[i][j] = match{ 96 | Language: trainingSample.Language, 97 | Correlation: correlation, 98 | } 99 | } 100 | sort.Sort(matchSorter(res[i])) 101 | } 102 | return res 103 | } 104 | 105 | func shuffleSamples(s []Sample) []Sample { 106 | res := make([]Sample, len(s)) 107 | 108 | p := rand.Perm(len(s)) 109 | for i, x := range p { 110 | res[i] = s[x] 111 | } 112 | 113 | return res 114 | } 115 | 116 | type matchSorter []match 117 | 118 | func (m matchSorter) Len() int { 119 | return len(m) 120 | } 121 | 122 | func (m matchSorter) Less(i, j int) bool { 123 | return m[i].Correlation > m[j].Correlation 124 | } 125 | 126 | func (m matchSorter) Swap(i, j int) { 127 | m[i], m[j] = m[j], m[i] 128 | } 129 | -------------------------------------------------------------------------------- /neuralnet/gradients.go: -------------------------------------------------------------------------------- 1 | package neuralnet 2 | 3 | import "github.com/unixpickle/num-analysis/kahan" 4 | 5 | // A gradientCalc can compute gradients of the 6 | // error function 0.5*||Actual - Expected||^2 7 | // for a neural network on a given input. 8 | // 9 | // A gradientCalc demands a lot of scratch 10 | // memory, so it is a good idea to create one 11 | // gradientCalc and then reuse it over and over. 12 | type gradientCalc struct { 13 | n *Network 14 | 15 | hiddenOutputs []float64 16 | outputs []float64 17 | inputs []float64 18 | expectedOut []float64 19 | 20 | hiddenOutPartials []float64 21 | 22 | OutputPartials [][]float64 23 | HiddenPartials [][]float64 24 | } 25 | 26 | func newGradientCalc(n *Network) *gradientCalc { 27 | res := &gradientCalc{ 28 | n: n, 29 | hiddenOutputs: make([]float64, len(n.HiddenWeights)), 30 | outputs: make([]float64, len(n.OutputWeights)), 31 | expectedOut: make([]float64, len(n.OutputWeights)), 32 | hiddenOutPartials: make([]float64, len(n.HiddenWeights)), 33 | OutputPartials: make([][]float64, len(n.OutputWeights)), 34 | HiddenPartials: make([][]float64, len(n.HiddenWeights)), 35 | } 36 | 37 | for i := range res.OutputPartials { 38 | res.OutputPartials[i] = make([]float64, len(res.hiddenOutputs)+1) 39 | } 40 | for i := range res.HiddenPartials { 41 | res.HiddenPartials[i] = make([]float64, len(n.Tokens)+1) 42 | } 43 | 44 | return res 45 | } 46 | 47 | func (g *gradientCalc) Compute(inputs []float64, langIdx int) { 48 | g.inputs = inputs 49 | for j := range g.expectedOut { 50 | if j == langIdx { 51 | g.expectedOut[j] = 1 52 | } else { 53 | g.expectedOut[j] = 0 54 | } 55 | } 56 | 57 | g.computeOutputs() 58 | g.computeGradients() 59 | } 60 | 61 | func (g *gradientCalc) computeOutputs() { 62 | outputSums := make([]*kahan.Summer64, len(g.outputs)) 63 | 64 | for i := range outputSums { 65 | outputSums[i] = kahan.NewSummer64() 66 | outputSums[i].Add(g.n.outputBias(i)) 67 | } 68 | 69 | for hiddenIndex, hiddenWeights := range g.n.HiddenWeights { 70 | hiddenSum := kahan.NewSummer64() 71 | for j, input := range g.inputs { 72 | hiddenSum.Add(input * hiddenWeights[j]) 73 | } 74 | hiddenSum.Add(g.n.hiddenBias(hiddenIndex)) 75 | 76 | hiddenOut := sigmoid(hiddenSum.Sum()) 77 | g.hiddenOutputs[hiddenIndex] = hiddenOut 78 | for j, outSum := range outputSums { 79 | weight := g.n.OutputWeights[j][hiddenIndex] 80 | outSum.Add(weight * hiddenOut) 81 | } 82 | } 83 | 84 | for i, sum := range outputSums { 85 | g.outputs[i] = sigmoid(sum.Sum()) 86 | } 87 | } 88 | 89 | func (g *gradientCalc) computeGradients() { 90 | for i, output := range g.outputs { 91 | gradient := 
g.OutputPartials[i]
92 |         diff := output - g.expectedOut[i]
93 |         sumPartial := (1 - output) * output * diff
94 |         for j, input := range g.hiddenOutputs {
95 |             gradient[j] = input * sumPartial
96 |             g.hiddenOutPartials[j] = g.n.OutputWeights[i][j] * sumPartial
97 |         }
98 |         gradient[len(g.hiddenOutputs)] = sumPartial
99 |     }
100 |     for i, output := range g.hiddenOutputs {
101 |         gradient := g.HiddenPartials[i]
102 |         sumPartial := (1 - output) * output * g.hiddenOutPartials[i]
103 |         for j, input := range g.inputs {
104 |             gradient[j] = input * sumPartial
105 |         }
106 |         gradient[len(g.inputs)] = sumPartial
107 |     }
108 | }
--------------------------------------------------------------------------------
/cmd/fetchlang/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 |     "fmt"
5 |     "io/ioutil"
6 |     "os"
7 |     "path/filepath"
8 |     "strconv"
9 | )
10 | 
11 | const MinFileSize = 100
12 | const MaxFileSize = 500000
13 | const MaxRequestsPerRepo = 20
14 | 
15 | type Language struct {
16 |     Name       string
17 |     Extensions []string
18 | }
19 | 
20 | var Languages = []Language{
21 |     {"ActionScript", []string{"as"}},
22 |     {"C", []string{"c"}},
23 |     {"C#", []string{"cs"}},
24 |     {"C++", []string{"cpp", "c++", "C", "cc"}},
25 |     {"Clojure", []string{"clj"}},
26 |     {"CoffeeScript", []string{"coffee"}},
27 |     {"CSS", []string{"css"}},
28 |     {"Go", []string{"go"}},
29 |     {"Haskell", []string{"hs"}},
30 |     {"HTML", []string{"html", "htm"}},
31 |     {"Java", []string{"java"}},
32 |     {"JavaScript", []string{"js"}},
33 |     {"Lua", []string{"lua"}},
34 |     {"Matlab", []string{"m"}},
35 |     {"Objective-C", []string{"m"}},
36 |     {"Perl", []string{"pl"}},
37 |     {"PHP", []string{"php"}},
38 |     {"Python", []string{"py"}},
39 |     {"R", []string{"r", "R"}},
40 |     {"Ruby", []string{"rb"}},
41 |     {"Scala", []string{"scala"}},
42 |     {"Shell", []string{"sh", "bash"}},
43 |     {"Swift", []string{"swift"}},
44 |     {"TeX", []string{"tex", "TeX"}},
45 |     {"VimL", []string{"vim"}},
46 | }
47 | 
48 | func main() {
49 |     if len(os.Args) != 3 {
50 |         fmt.Fprintln(os.Stderr, "Usage: fetchlang <sample-dir> <sample-count>")
51 |         os.Exit(1)
52 |     }
53 |     sampleDir := os.Args[1]
54 |     sampleCount, err := strconv.Atoi(os.Args[2])
55 |     if err != nil {
56 |         fmt.Fprintln(os.Stderr, "Invalid sample count:", os.Args[2])
57 |         os.Exit(1)
58 |     }
59 | 
60 |     client, err := PromptGithubClient()
61 |     if err != nil {
62 |         fmt.Fprintln(os.Stderr, "Failed to read credentials:", err)
63 |         os.Exit(1)
64 |     }
65 | 
66 |     for _, lang := range Languages {
67 |         fmt.Println("Fetching samples for", lang.Name, "...")
68 |         err := fetchLanguage(client, sampleDir, sampleCount, lang)
69 |         if err != nil {
70 |             fmt.Println("Error:", err)
71 |         }
72 |     }
73 | }
74 | 
75 | func fetchLanguage(github *GithubClient, sampleDir string, count int, lang Language) error {
76 |     if err := os.Mkdir(filepath.Join(sampleDir, lang.Name), 0755); err != nil {
77 |         return err
78 |     }
79 | 
80 |     doneChan := make(chan struct{}, 1)
81 |     repoChan, errChan := github.Search(lang.Name, doneChan)
82 | 
83 |     defer func() {
84 |         close(doneChan)
85 |     }()
86 | 
87 |     resCount := 0
88 |     for repo := range repoChan {
89 |         file, err := github.SearchFile(FileSearch{
90 |             Repository:  repo,
91 |             MinFileSize: MinFileSize,
92 |             MaxFileSize: MaxFileSize,
93 |             Extensions:  lang.Extensions,
94 |             MaxRequests: MaxRequestsPerRepo,
95 |         })
96 |         if err == ErrNoResults {
97 |             fmt.Println("No results for:", repo)
98 |             continue
99 |         } else if err == ErrMaxRequests {
100 |             fmt.Println("Max requests exceeded:", repo)
101 |             continue
102 |
} else if err != nil { 103 | return err 104 | } else if file == nil { 105 | continue 106 | } 107 | 108 | fileName := fmt.Sprintf("%d.%s", resCount, lang.Extensions[0]) 109 | targetFile := filepath.Join(sampleDir, lang.Name, fileName) 110 | if err := ioutil.WriteFile(targetFile, file, 0755); err != nil { 111 | return err 112 | } 113 | 114 | resCount++ 115 | if resCount == count { 116 | break 117 | } 118 | } 119 | 120 | select { 121 | case err := <-errChan: 122 | return err 123 | default: 124 | return nil 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /tokens/counts.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | ) 7 | 8 | // Counts records the number of occurrences 9 | // of tokens in a given document. 10 | type Counts map[string]int 11 | 12 | // CountTokens counts the tokens in a document. 13 | // 14 | // Four different types of tokens are detected: 15 | // 16 | // - Heterogeneous tokens: any strings which 17 | // appear surrounded by whitespace. 18 | // - Homogeneous words: strings like "abcd" 19 | // or "123" which are one type of symbol. 20 | // - Line-initial words: both heterogeneous 21 | // and homogeneous words which begin a line. 22 | // These tokens start with "\n". 23 | // - Line-final words: both heterogeneous and 24 | // homogeneous words which end a line. 25 | // These tokens end with "\n". 26 | // 27 | // No homogeneous tokens will be counted as 28 | // heterogeneous tokens, or vice versa. 29 | // All line-boundary words are counted twice, 30 | // once with the newline and once without it. 31 | // If one token makes up an entire line, it is 32 | // counted as both line-initial and line-final. 
33 | func CountTokens(contents string) Counts { 34 | res := Counts{} 35 | for _, t := range heterogeneousTokens(contents) { 36 | res[t] += 1 37 | } 38 | for _, t := range homogeneousTokens(contents) { 39 | res[t] += 1 40 | } 41 | for _, t := range lineBoundaryTokens(contents) { 42 | res[t] += 1 43 | } 44 | return res 45 | } 46 | 47 | func lineBoundaryTokens(contents string) []string { 48 | var res []string 49 | 50 | lines := strings.Split(contents, "\n") 51 | for _, line := range lines { 52 | fields := strings.Fields(line) 53 | if len(fields) == 0 { 54 | continue 55 | } 56 | for i, field := range []string{fields[0], fields[len(fields)-1]} { 57 | homog := homogeneousTokens(field) 58 | hetero := heterogeneousTokens(field) 59 | for _, tokList := range [][]string{homog, hetero} { 60 | if len(tokList) > 0 { 61 | if i == 0 { 62 | res = append(res, "\n"+tokList[0]) 63 | } else { 64 | res = append(res, tokList[len(tokList)-1]+"\n") 65 | } 66 | } 67 | } 68 | } 69 | } 70 | 71 | return res 72 | } 73 | 74 | func heterogeneousTokens(contents string) []string { 75 | fields := strings.Fields(contents) 76 | res := make([]string, 0, len(fields)) 77 | for _, f := range fields { 78 | if !isHeterogeneous(f) { 79 | res = append(res, f) 80 | } 81 | } 82 | return res 83 | } 84 | 85 | func homogeneousTokens(contents string) []string { 86 | tokens := []string{} 87 | res := "" 88 | lastClass := charClassSpace 89 | for _, ch := range contents { 90 | c := classForRune(ch) 91 | if c == lastClass { 92 | res += string(ch) 93 | continue 94 | } 95 | if lastClass != charClassSpace && len(res) > 0 { 96 | tokens = append(tokens, res) 97 | } 98 | res = string(ch) 99 | lastClass = c 100 | } 101 | if lastClass != charClassSpace && len(res) > 0 { 102 | tokens = append(tokens, res) 103 | } 104 | return tokens 105 | } 106 | 107 | type charClass int 108 | 109 | const ( 110 | charClassLetter charClass = iota 111 | charClassNumber 112 | charClassSpace 113 | charClassSymbol 114 | ) 115 | 116 | func classForRune(r rune) charClass { 117 | if unicode.IsLetter(r) { 118 | return charClassLetter 119 | } else if unicode.IsDigit(r) { 120 | return charClassNumber 121 | } else if unicode.IsSpace(r) { 122 | return charClassSpace 123 | } 124 | return charClassSymbol 125 | } 126 | 127 | func isHeterogeneous(s string) bool { 128 | if len(s) == 0 { 129 | return true 130 | } 131 | c := classForRune([]rune(s)[0]) 132 | for _, r := range s { 133 | if classForRune(r) != c { 134 | return false 135 | } 136 | } 137 | return true 138 | } 139 | -------------------------------------------------------------------------------- /cmd/fetchlang-pastie/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | "os" 9 | "path/filepath" 10 | "regexp" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | 15 | "github.com/yhat/scrape" 16 | "golang.org/x/net/html" 17 | "golang.org/x/net/html/atom" 18 | ) 19 | 20 | const RoutineCount = 8 21 | 22 | func main() { 23 | if len(os.Args) != 4 { 24 | fmt.Fprintln(os.Stderr, "Usage: fetchlang-pastie ") 25 | os.Exit(1) 26 | } 27 | startIdx, err := strconv.Atoi(os.Args[1]) 28 | if err != nil { 29 | fmt.Fprintln(os.Stderr, err) 30 | os.Exit(1) 31 | } 32 | endIdx, err := strconv.Atoi(os.Args[2]) 33 | if err != nil { 34 | fmt.Fprintln(os.Stderr, err) 35 | os.Exit(1) 36 | } 37 | pasteIndices := make(chan int) 38 | go func() { 39 | for i := startIdx; i <= endIdx; i++ { 40 | pasteIndices <- i 41 | } 42 | 
close(pasteIndices) 43 | }() 44 | outDir := os.Args[3] 45 | if err := ensureDirectoryPresent(outDir); err != nil { 46 | fmt.Fprintln(os.Stderr, err) 47 | os.Exit(1) 48 | } 49 | fetchPastes(pasteIndices, outDir) 50 | } 51 | 52 | func fetchPastes(indices <-chan int, outDir string) { 53 | var wg sync.WaitGroup 54 | for i := 0; i < RoutineCount; i++ { 55 | wg.Add(1) 56 | go func() { 57 | defer wg.Done() 58 | for index := range indices { 59 | if err := fetchPaste(index, outDir); err != nil { 60 | fmt.Fprintln(os.Stderr, "error for paste", index, err) 61 | } else { 62 | fmt.Println("succeeded for paste", index) 63 | } 64 | } 65 | }() 66 | } 67 | wg.Wait() 68 | } 69 | 70 | func fetchPaste(index int, outDir string) error { 71 | code, lang, err := fetchPasteCode(index) 72 | if err != nil { 73 | return err 74 | } 75 | codeDir := filepath.Join(outDir, lang) 76 | ensureDirectoryPresent(codeDir) 77 | fileName := strconv.Itoa(index) + ".txt" 78 | return ioutil.WriteFile(filepath.Join(codeDir, fileName), []byte(code), 0755) 79 | } 80 | 81 | func fetchPasteCode(index int) (contents, language string, err error) { 82 | response, err := http.Get("http://pastie.org/pastes/" + strconv.Itoa(index)) 83 | if err != nil { 84 | return 85 | } 86 | body, err := ioutil.ReadAll(response.Body) 87 | response.Body.Close() 88 | if err != nil { 89 | return 90 | } 91 | pageData := string(body) 92 | exp := regexp.MustCompile("\\\n(.*?)\n ") 93 | match := exp.FindStringSubmatch(pageData) 94 | if match == nil { 95 | ioutil.WriteFile("/Users/alex/Desktop/foo.html", []byte(pageData), 0755) 96 | return "", "", errors.New("cannot locate language") 97 | } 98 | language = match[1] 99 | language = strings.Replace(language, "/", ":", -1) 100 | 101 | response, err = http.Get("http://pastie.org/pastes/" + strconv.Itoa(index) + "/text") 102 | if err != nil { 103 | return 104 | } 105 | root, err := html.Parse(response.Body) 106 | response.Body.Close() 107 | if err != nil { 108 | return 109 | } 110 | codeBlock, ok := scrape.Find(root, scrape.ByTag(atom.Pre)) 111 | if !ok { 112 | return "", "", errors.New("no
 tag")
113 | 	}
114 | 	contents = codeBlockText(codeBlock)
115 | 	return
116 | }
117 | 
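// codeBlockText recursively extracts the visible text of
// a parsed HTML node, converting <br> tags to newlines.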
118 | func codeBlockText(n *html.Node) string {
119 | 	if n.DataAtom == atom.Br {
120 | 		return "\n"
121 | 	}
122 | 	if n.Type == html.TextNode {
123 | 		return n.Data
124 | 	}
125 | 
126 | 	var res string
127 | 	child := n.FirstChild
128 | 	for child != nil {
129 | 		res += codeBlockText(child)
130 | 		child = child.NextSibling
131 | 	}
132 | 	return res
133 | }
134 | 
135 | func ensureDirectoryPresent(dirPath string) error {
136 | 	if _, err := os.Stat(dirPath); err != nil {
137 | 		if err := os.Mkdir(dirPath, 0755); err != nil {
138 | 			return err
139 | 		}
140 | 	}
141 | 	return nil
142 | }
143 | 


--------------------------------------------------------------------------------
/neuralnet/classifier.go:
--------------------------------------------------------------------------------
  1 | package neuralnet
  2 | 
  3 | import (
  4 | 	"encoding/json"
  5 | 	"math"
  6 | 
  7 | 	"github.com/unixpickle/num-analysis/kahan"
  8 | 	"github.com/unixpickle/whichlang/tokens"
  9 | )
 10 | 
 11 | // A Network is a feedforward neural network with
 12 | // a single hidden layer.
 13 | type Network struct {
 14 | 	Tokens []string
 15 | 	Langs  []string
 16 | 
 17 | 	// In the following weights, the last weight for
 18 | 	// each neuron corresponds to a constant shift,
 19 | 	// and is not multiplied by an input's value.
 20 | 	HiddenWeights [][]float64
 21 | 	OutputWeights [][]float64
 22 | 
 23 | 	// InputShift and InputScale are used to center
 24 | 	// the input frequencies around 0 and scale them
 25 | 	// to a standard deviation of 1.
 26 | 	InputShift float64
 27 | 	InputScale float64
 28 | }
 29 | 
 30 | func DecodeNetwork(data []byte) (*Network, error) {
 31 | 	var n Network
 32 | 	if err := json.Unmarshal(data, &n); err != nil {
 33 | 		return nil, err
 34 | 	}
 35 | 	return &n, nil
 36 | }
 37 | 
 38 | func (n *Network) Copy() *Network {
 39 | 	res := &Network{
 40 | 		Tokens:        make([]string, len(n.Tokens)),
 41 | 		Langs:         make([]string, len(n.Langs)),
 42 | 		HiddenWeights: make([][]float64, len(n.HiddenWeights)),
 43 | 		OutputWeights: make([][]float64, len(n.OutputWeights)),
 44 | 		InputShift:    n.InputShift,
 45 | 		InputScale:    n.InputScale,
 46 | 	}
 47 | 	copy(res.Tokens, n.Tokens)
 48 | 	copy(res.Langs, n.Langs)
 49 | 	for i, w := range n.HiddenWeights {
 50 | 		res.HiddenWeights[i] = make([]float64, len(w))
 51 | 		copy(res.HiddenWeights[i], w)
 52 | 	}
 53 | 	for i, w := range n.OutputWeights {
 54 | 		res.OutputWeights[i] = make([]float64, len(w))
 55 | 		copy(res.OutputWeights[i], w)
 56 | 	}
 57 | 	return res
 58 | }
 59 | 
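// Classify feeds a document's token frequencies through the
// network and returns the language whose output neuron has
// the largest weighted sum. The output sigmoid is monotonic,
// so comparing the raw sums picks the same language as
// comparing the activated outputs would.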
 60 | func (n *Network) Classify(freqs tokens.Freqs) string {
 61 | 	inputs := n.shiftedInput(freqs)
 62 | 
 63 | 	outputSums := make([]*kahan.Summer64, len(n.OutputWeights))
 64 | 	for i := range outputSums {
 65 | 		outputSums[i] = kahan.NewSummer64()
 66 | 		outputSums[i].Add(n.outputBias(i))
 67 | 	}
 68 | 
 69 | 	for hiddenIndex, hiddenWeights := range n.HiddenWeights {
 70 | 		hiddenSum := kahan.NewSummer64()
 71 | 		for j, input := range inputs {
 72 | 			hiddenSum.Add(input * hiddenWeights[j])
 73 | 		}
 74 | 		hiddenSum.Add(n.hiddenBias(hiddenIndex))
 75 | 
 76 | 		hiddenOut := sigmoid(hiddenSum.Sum())
 77 | 		for j, outSum := range outputSums {
 78 | 			weight := n.OutputWeights[j][hiddenIndex]
 79 | 			outSum.Add(weight * hiddenOut)
 80 | 		}
 81 | 	}
 82 | 
 83 | 	maxSum := outputSums[0].Sum()
 84 | 	maxIdx := 0
 85 | 	for i, x := range outputSums {
 86 | 		if x.Sum() > maxSum {
 87 | 			maxSum = x.Sum()
 88 | 			maxIdx = i
 89 | 		}
 90 | 	}
 91 | 	return n.Langs[maxIdx]
 92 | }
 93 | 
 94 | func (n *Network) Encode() []byte {
 95 | 	enc, _ := json.Marshal(n)
 96 | 	return enc
 97 | }
 98 | 
 99 | func (n *Network) Languages() []string {
100 | 	return n.Langs
101 | }
102 | 
103 | func (n *Network) outputBias(outputIdx int) float64 {
104 | 	return n.OutputWeights[outputIdx][len(n.HiddenWeights)]
105 | }
106 | 
107 | func (n *Network) hiddenBias(hiddenIdx int) float64 {
108 | 	return n.HiddenWeights[hiddenIdx][len(n.Tokens)]
109 | }
110 | 
111 | func (n *Network) containsNaN() bool {
112 | 	for _, wss := range [][][]float64{n.HiddenWeights, n.OutputWeights} {
113 | 		for _, ws := range wss {
114 | 			for _, w := range ws {
115 | 				if math.IsNaN(w) {
116 | 					return true
117 | 				}
118 | 			}
119 | 		}
120 | 	}
121 | 	return false
122 | }
123 | 
124 | func (n *Network) shiftedInput(f tokens.Freqs) []float64 {
125 | 	res := make([]float64, len(n.Tokens))
126 | 	for i, word := range n.Tokens {
127 | 		res[i] = (f[word] + n.InputShift) * n.InputScale
128 | 	}
129 | 	return res
130 | }
131 | 
132 | func sigmoid(x float64) float64 {
133 | 	return 1.0 / (1.0 + math.Exp(-x))
134 | }
135 | 


--------------------------------------------------------------------------------
/tokens/sample_counts.go:
--------------------------------------------------------------------------------
  1 | package tokens
  2 | 
  3 | import (
  4 | 	"io/ioutil"
  5 | 	"os"
  6 | 	"path/filepath"
  7 | 	"sort"
  8 | 	"strings"
  9 | )
 10 | 
 11 | // SampleCounts maps programming languages to
 12 | // arrays of language sample documents, where
 13 | // each sample is represented by a Counts.
 14 | type SampleCounts map[string][]Counts
 15 | 
 16 | // ReadSampleCounts computes token counts
 17 | // for programming language samples in a
 18 | // directory.
 19 | //
 20 | // The directory should contain sub-directories
 21 | // for each programming language, and each of
 22 | // these languages should contain one or more
 23 | // source files.
 24 | //
 25 | // The returned map maps language names to lists
 26 | // of Counts, where each Counts corresponds to
 27 | // one source file.
 28 | func ReadSampleCounts(sampleDir string) (SampleCounts, error) {
 29 | 	languages, err := readDirectory(sampleDir, true)
 30 | 	if err != nil {
 31 | 		return nil, err
 32 | 	}
 33 | 
 34 | 	res := SampleCounts{}
 35 | 	for _, language := range languages {
 36 | 		langDir := filepath.Join(sampleDir, language)
 37 | 		files, err := readDirectory(langDir, false)
 38 | 		if err != nil {
 39 | 			return nil, err
 40 | 		}
 41 | 		for _, file := range files {
 42 | 			contents, err := ioutil.ReadFile(filepath.Join(langDir, file))
 43 | 			if err != nil {
 44 | 				return nil, err
 45 | 			}
 46 | 			counts := CountTokens(string(contents))
 47 | 			res[language] = append(res[language], counts)
 48 | 		}
 49 | 	}
 50 | 
 51 | 	return res, nil
 52 | }
 53 | 
 54 | // NumTokens returns the number of unique
 55 | // tokens in all the documents.
 56 | func (s SampleCounts) NumTokens() int {
 57 | 	toks := map[string]bool{}
 58 | 	for _, samples := range s {
 59 | 		for _, sample := range samples {
 60 | 			for word := range sample {
 61 | 				toks[word] = true
 62 | 			}
 63 | 		}
 64 | 	}
 65 | 	return len(toks)
 66 | }
 67 | 
 68 | // Prune removes tokens which appear in n
 69 | // documents or fewer.
 70 | //
 71 | // This creates a "" token in each document
 72 | // corresponding to the number of pruned
 73 | // tokens from that document.
 74 | func (s SampleCounts) Prune(n int) {
 75 | 	docCount := map[string]int{}
 76 | 	for _, samples := range s {
 77 | 		for _, sample := range samples {
 78 | 			for word := range sample {
 79 | 				docCount[word]++
 80 | 			}
 81 | 		}
 82 | 	}
 83 | 
 84 | 	remove := map[string]bool{}
 85 | 	for word, count := range docCount {
 86 | 		if count <= n {
 87 | 			remove[word] = true
 88 | 		}
 89 | 	}
 90 | 
 91 | 	for _, samples := range s {
 92 | 		for i, sample := range samples {
 93 | 			newSample := map[string]int{}
 94 | 			removed := 0
 95 | 			for word, count := range sample {
 96 | 				if !remove[word] {
 97 | 					newSample[word] = count
 98 | 				} else {
 99 | 					removed += count
100 | 				}
101 | 			}
102 | 			if removed > 0 {
103 | 				newSample[""] += removed
104 | 			}
105 | 			samples[i] = newSample
106 | 		}
107 | 	}
108 | }
109 | 
110 | // SampleFreqs converts every Counts object
111 | // in s into a Freqs object.
112 | // The "" key in each Freqs object is deleted
113 | // if one exists.
114 | func (s SampleCounts) SampleFreqs() map[string][]Freqs {
115 | 	res := map[string][]Freqs{}
116 | 	for lang, samples := range s {
117 | 		for _, sample := range samples {
118 | 			f := sample.Freqs()
119 | 			delete(f, "")
120 | 			res[lang] = append(res[lang], f)
121 | 		}
122 | 	}
123 | 	return res
124 | }
125 | 
126 | func readDirectory(dir string, isDir bool) ([]string, error) {
127 | 	f, err := os.Open(dir)
128 | 	if err != nil {
129 | 		return nil, err
130 | 	}
131 | 	defer f.Close()
132 | 	contents, err := f.Readdir(-1)
133 | 	if err != nil {
134 | 		return nil, err
135 | 	}
136 | 	res := make([]string, 0, len(contents))
137 | 	for _, info := range contents {
138 | 		if info.IsDir() == isDir && !strings.HasPrefix(info.Name(), ".") {
139 | 			res = append(res, info.Name())
140 | 		}
141 | 	}
142 | 	sort.Strings(res)
143 | 	return res, nil
144 | }
145 | 


--------------------------------------------------------------------------------
/cmd/fetchlang/file_search.go:
--------------------------------------------------------------------------------
  1 | package main
  2 | 
  3 | import (
  4 | 	"encoding/base64"
  5 | 	"encoding/json"
  6 | 	"errors"
  7 | 	"net/url"
  8 | 	"path"
  9 | 	"strings"
 10 | )
 11 | 
 12 | var (
 13 | 	ErrNoResults   = errors.New("no results")
 14 | 	ErrMaxRequests = errors.New("too many requests")
 15 | )
 16 | 
 17 | // A FileSearch defines parameters for
 18 | // searching a repository for files.
 19 | type FileSearch struct {
 20 | 	// Repository is the repository name,
 21 | 	// formatted as "user/repo".
 22 | 	Repository string
 23 | 
 24 | 	MinFileSize int
 25 | 	MaxFileSize int
 26 | 	Extensions  []string
 27 | 
 28 | 	// MaxRequests is the maximum number of
 29 | 	// API requests to be performed by the
 30 | 	// search before giving up.
 31 | 	MaxRequests int
 32 | }
 33 | 
 34 | // SearchFile runs a FileSearch.
 35 | // It returns ErrMaxRequests if more than
 36 | // s.MaxRequests requests are used.
 37 | // It returns ErrNoResults if no results
 38 | // are found.
 39 | func (g *GithubClient) SearchFile(s FileSearch) (contents []byte, err error) {
 40 | 	return g.firstFileSearch(&s, "/")
 41 | }
 42 | 
 43 | func (g *GithubClient) firstFileSearch(s *FileSearch, dir string) (match []byte, err error) {
 44 | 	if s.MaxRequests == 0 {
 45 | 		return nil, ErrMaxRequests
 46 | 	}
 47 | 
 48 | 	u := url.URL{
 49 | 		Scheme: "https",
 50 | 		Host:   "api.github.com",
 51 | 		Path:   path.Join("/repos", s.Repository, "/contents", dir),
 52 | 	}
 53 | 
 54 | 	body, _, err := g.request(u.String())
 55 | 	if err != nil {
 56 | 		return nil, err
 57 | 	}
 58 | 
 59 | 	s.MaxRequests--
 60 | 
 61 | 	var result []entity
 62 | 	if err := json.Unmarshal(body, &result); err != nil {
 63 | 		return nil, err
 64 | 	}
 65 | 
 66 | 	for _, ent := range result {
 67 | 		if ent.Match(s) {
 68 | 			return g.readFile(s.Repository, ent.Path)
 69 | 		}
 70 | 	}
 71 | 
 72 | 	sourceDirectoryHeuristic(result, s.Repository)
 73 | 
 74 | 	for _, ent := range result {
 75 | 		if ent.Dir() {
 76 | 			match, err = g.firstFileSearch(s, ent.Path)
 77 | 			if match != nil || (err != nil && err != ErrNoResults) {
 78 | 				return
 79 | 			}
 80 | 		}
 81 | 	}
 82 | 
 83 | 	return nil, ErrNoResults
 84 | }
 85 | 
 86 | func (g *GithubClient) readFile(repo, filePath string) ([]byte, error) {
 87 | 	u := url.URL{
 88 | 		Scheme: "https",
 89 | 		Host:   "api.github.com",
 90 | 		Path:   path.Join("/repos", repo, "/contents", filePath),
 91 | 	}
 92 | 	body, _, err := g.request(u.String())
 93 | 	if err != nil {
 94 | 		return nil, err
 95 | 	}
 96 | 
 97 | 	var result struct {
 98 | 		Content  string `json:"content"`
 99 | 		Encoding string `json:"encoding"`
100 | 	}
101 | 	if err := json.Unmarshal(body, &result); err != nil {
102 | 		return nil, err
103 | 	}
104 | 
105 | 	if result.Encoding == "base64" {
106 | 		return base64.StdEncoding.DecodeString(result.Content)
107 | 	} else {
108 | 		return nil, errors.New("unknown encoding: " + result.Encoding)
109 | 	}
110 | }
111 | 
112 | type entity struct {
113 | 	Name string `json:"name"`
114 | 	Path string `json:"path"`
115 | 	Size int    `json:"size"`
116 | 	Type string `json:"type"`
117 | }
118 | 
119 | func (e *entity) Dir() bool {
120 | 	return e.Type == "dir"
121 | }
122 | 
123 | func (e *entity) Match(s *FileSearch) bool {
124 | 	if e.Type != "file" {
125 | 		return false
126 | 	}
127 | 	if e.Size < s.MinFileSize || e.Size > s.MaxFileSize {
128 | 		return false
129 | 	}
130 | 	for _, ext := range s.Extensions {
131 | 		if strings.HasSuffix(e.Name, "."+ext) {
132 | 			return true
133 | 		}
134 | 	}
135 | 	return false
136 | }
137 | 
138 | // sourceDirectoryHeuristic sorts a list of
139 | // entities so that the first ones are more
140 | // likely to contain source code.
141 | func sourceDirectoryHeuristic(results []entity, repoName string) {
 142 | 	sourceDirs := []string{"src", path.Base(repoName), "lib", "com", "org", "net", "css", "assets"}
143 | 	numFound := 0
144 | 	for _, sourceDir := range sourceDirs {
145 | 		for i, ent := range results[numFound:] {
146 | 			if ent.Dir() && ent.Name == sourceDir {
 147 | 				results[numFound], results[numFound+i] = results[numFound+i], results[numFound]
148 | 				numFound++
149 | 				break
150 | 			}
151 | 		}
152 | 	}
153 | }
154 | 


--------------------------------------------------------------------------------
/neuralnet/data_set.go:
--------------------------------------------------------------------------------
  1 | package neuralnet
  2 | 
  3 | import (
  4 | 	"math"
  5 | 	"math/rand"
  6 | 	"sort"
  7 | 
  8 | 	"github.com/unixpickle/num-analysis/kahan"
  9 | 	"github.com/unixpickle/whichlang/tokens"
 10 | )
 11 | 
 12 | const ValidationFraction = 0.3
 13 | 
 14 | // A DataSet is a set of data split into training
 15 | // samples and validation samples.
 16 | type DataSet struct {
 17 | 	ValidationSamples     map[string][]tokens.Freqs
 18 | 	TrainingSamples       map[string][]tokens.Freqs
 19 | 	NormalTrainingSamples map[string][][]float64
 20 | 
 21 | 	// These are statistical properties of the
 22 | 	// training samples' frequency values.
 23 | 	MeanFrequency   float64
 24 | 	FrequencyStddev float64
 25 | }
 26 | 
 27 | // NewDataSet creates a DataSet by randomly
 28 | // partitioning some data samples into
 29 | // validation and training samples.
 30 | func NewDataSet(samples map[string][]tokens.Freqs) *DataSet {
 31 | 	res := &DataSet{
 32 | 		ValidationSamples: map[string][]tokens.Freqs{},
 33 | 		TrainingSamples:   map[string][]tokens.Freqs{},
 34 | 	}
 35 | 	for lang, langSamples := range samples {
 36 | 		shuffled := make([]tokens.Freqs, len(langSamples))
 37 | 		perm := rand.Perm(len(shuffled))
 38 | 		for i, x := range perm {
 39 | 			shuffled[i] = langSamples[x]
 40 | 		}
 41 | 
 42 | 		numValid := int(float64(len(langSamples)) * ValidationFraction)
 43 | 		res.ValidationSamples[lang] = shuffled[:numValid]
 44 | 		res.TrainingSamples[lang] = shuffled[numValid:]
 45 | 	}
 46 | 
 47 | 	res.computeStatistics()
 48 | 	res.computeNormalSamples()
 49 | 
 50 | 	return res
 51 | }
 52 | 
  53 | // CrossScore returns the fraction of withheld
  54 | // validation samples the Network classifies correctly.
 55 | func (c *DataSet) CrossScore(n *Network) float64 {
 56 | 	return scoreNetwork(n, c.ValidationSamples)
 57 | }
 58 | 
  59 | // TrainingScore returns the fraction of
  60 | // training samples the Network classifies correctly.
 61 | func (c *DataSet) TrainingScore(n *Network) float64 {
 62 | 	return scoreNetwork(n, c.TrainingSamples)
 63 | }
 64 | 
 65 | // Tokens returns all of the tokens from all
 66 | // of the training samples.
 67 | func (c *DataSet) Tokens() []string {
 68 | 	toks := map[string]bool{}
 69 | 	for _, samples := range c.TrainingSamples {
 70 | 		for _, sample := range samples {
 71 | 			for tok := range sample {
 72 | 				toks[tok] = true
 73 | 			}
 74 | 		}
 75 | 	}
 76 | 
 77 | 	res := make([]string, 0, len(toks))
 78 | 	for tok := range toks {
 79 | 		res = append(res, tok)
 80 | 	}
 81 | 	sort.Strings(res)
 82 | 	return res
 83 | }
 84 | 
 85 | // Langs returns all of the languages represented
 86 | // by the training samples.
 87 | func (c *DataSet) Langs() []string {
 88 | 	res := make([]string, 0, len(c.TrainingSamples))
 89 | 	for lang := range c.TrainingSamples {
 90 | 		res = append(res, lang)
 91 | 	}
 92 | 	sort.Strings(res)
 93 | 	return res
 94 | }
 95 | 
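// computeStatistics computes the mean and standard deviation
// of token frequencies across all training samples. These
// statistics become a Network's InputShift and InputScale,
// normalizing inputs to mean 0 and standard deviation 1.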
 96 | func (c *DataSet) computeStatistics() {
 97 | 	tokens := c.Tokens()
 98 | 
 99 | 	freqSum := kahan.NewSummer64()
100 | 	freqCount := 0
101 | 	for _, langSamples := range c.TrainingSamples {
102 | 		for _, sample := range langSamples {
103 | 			freqCount += len(tokens)
104 | 			for _, freq := range sample {
105 | 				freqSum.Add(freq)
106 | 			}
107 | 		}
108 | 	}
109 | 
110 | 	c.MeanFrequency = freqSum.Sum() / float64(freqCount)
111 | 
112 | 	variationSum := kahan.NewSummer64()
113 | 	for _, langSamples := range c.TrainingSamples {
114 | 		for _, sample := range langSamples {
115 | 			for _, token := range tokens {
116 | 				freq := sample[token]
117 | 				variationSum.Add(math.Pow(freq-c.MeanFrequency, 2))
118 | 			}
119 | 		}
120 | 	}
121 | 
122 | 	c.FrequencyStddev = math.Sqrt(variationSum.Sum() / float64(freqCount))
123 | }
124 | 
125 | func (c *DataSet) computeNormalSamples() {
126 | 	c.NormalTrainingSamples = map[string][][]float64{}
127 | 	tokens := c.Tokens()
128 | 
129 | 	for lang, langSamples := range c.TrainingSamples {
130 | 		sampleList := make([][]float64, len(langSamples))
131 | 		for i, sample := range langSamples {
132 | 			sampleVec := make([]float64, len(tokens))
133 | 			for j, token := range tokens {
134 | 				sampleVec[j] = (sample[token] - c.MeanFrequency) / c.FrequencyStddev
135 | 			}
136 | 			sampleList[i] = sampleVec
137 | 		}
138 | 		c.NormalTrainingSamples[lang] = sampleList
139 | 	}
140 | }
141 | 
142 | func scoreNetwork(n *Network, samples map[string][]tokens.Freqs) float64 {
143 | 	var totalRight int
144 | 	var total int
145 | 	for lang, langSamples := range samples {
146 | 		for _, sample := range langSamples {
147 | 			if n.Classify(sample) == lang {
148 | 				totalRight++
149 | 			}
150 | 			total++
151 | 		}
152 | 	}
153 | 	return float64(totalRight) / float64(total)
154 | }
155 | 


--------------------------------------------------------------------------------
/svm/trainer_params.go:
--------------------------------------------------------------------------------
  1 | package svm
  2 | 
  3 | import (
  4 | 	"errors"
  5 | 	"os"
  6 | 	"strconv"
  7 | )
  8 | 
  9 | const (
 10 | 	defaultTradeoff                = 1e-5
 11 | 	defaultCrossValidationFraction = 0.3
 12 | )
 13 | 
 14 | var (
 15 | 	defaultRBFParams  = [][]float64{{1e-5}, {1e-4}, {1e-3}, {1e-2}, {1e-1}, {1e0}, {1e1}, {1e2}}
 16 | 	defaultPolyPowers = []float64{2}
 17 | 	defaultPolySums   = []float64{0, 1}
 18 | )
 19 | 
 20 | // These environment variables specify
 21 | // various parameters for the SVM trainer.
 22 | const (
 23 | 	// Set this to "1" to get verbose logs.
 24 | 	VerboseEnvVar = "SVM_VERBOSE"
 25 | 
 26 | 	// You may set this to "linear", "rbf", or
 27 | 	// "polynomial".
 28 | 	KernelEnvVar = "SVM_KERNEL"
 29 | 
 30 | 	// The numerical constant used in the
 31 | 	// RBF kernel.
 32 | 	RBFParamEnvVar = "SVM_RBF_PARAM"
 33 | 
 34 | 	// The degree parameter for polynomial kernels.
 35 | 	PolyDegreeEnvVar = "SVM_POLY_DEGREE"
 36 | 
  37 | 	// The additive constant (applied before raising
  38 | 	// to the degree) for polynomial kernels.
 39 | 	PolySumEnvVar = "SVM_POLY_SUM"
 40 | 
 41 | 	// The tradeoff between margin size and hinge loss.
 42 | 	// The higher the tradeoff value, the greater the
 43 | 	// margin size, but at the expense of correct
 44 | 	// classifications.
 45 | 	TradeoffEnvVar = "SVM_TRADEOFF"
 46 | 
 47 | 	// The fraction (from 0-1) of samples which are
 48 | 	// used for cross validation.
 49 | 	CrossValidationEnvVar = "SVM_CROSS_VALIDATION"
 50 | )
 51 | 
 52 | // TrainerParams specifies parameters for the
 53 | // SVM trainer.
 54 | type TrainerParams struct {
 55 | 	Verbose  bool
 56 | 	Kernels  []*Kernel
 57 | 	Tradeoff float64
 58 | 
 59 | 	CrossValidation float64
 60 | }
 61 | 
 62 | // EnvTrainerParams generates TrainerParams
 63 | // by reading environment variables.
 64 | // If an environment variable is incorrectly
 65 | // formatted, this returns an error.
 66 | // When a variable is missing, a default value
 67 | // or set of values will be used.
 68 | func EnvTrainerParams() (*TrainerParams, error) {
 69 | 	var res TrainerParams
 70 | 	var err error
 71 | 
 72 | 	if res.Tradeoff, err = envTradeoff(); err != nil {
 73 | 		return nil, err
 74 | 	}
 75 | 	if res.CrossValidation, err = envCrossValidation(); err != nil {
 76 | 		return nil, err
 77 | 	}
 78 | 	res.Verbose = (os.Getenv(VerboseEnvVar) == "1")
 79 | 
 80 | 	kernTypes, err := envKernelTypes()
 81 | 	if err != nil {
 82 | 		return nil, err
 83 | 	}
 84 | 
 85 | 	for _, kernType := range kernTypes {
 86 | 		params, err := envKernelParams(kernType)
 87 | 		if err != nil {
 88 | 			return nil, err
 89 | 		}
 90 | 		for _, param := range params {
 91 | 			kernel := &Kernel{
 92 | 				Type:   kernType,
 93 | 				Params: param,
 94 | 			}
 95 | 			res.Kernels = append(res.Kernels, kernel)
 96 | 		}
 97 | 	}
 98 | 
 99 | 	return &res, nil
100 | }
101 | 
102 | func envTradeoff() (float64, error) {
103 | 	if val := os.Getenv(TradeoffEnvVar); val != "" {
104 | 		return strconv.ParseFloat(val, 64)
105 | 	} else {
106 | 		return defaultTradeoff, nil
107 | 	}
108 | }
109 | 
110 | func envCrossValidation() (float64, error) {
111 | 	if val := os.Getenv(CrossValidationEnvVar); val != "" {
112 | 		return strconv.ParseFloat(val, 64)
113 | 	} else {
114 | 		return defaultCrossValidationFraction, nil
115 | 	}
116 | }
117 | 
118 | func envKernelTypes() ([]KernelType, error) {
119 | 	if val := os.Getenv(KernelEnvVar); val != "" {
120 | 		res, ok := map[string]KernelType{
121 | 			"linear":     LinearKernel,
122 | 			"polynomial": PolynomialKernel,
123 | 			"rbf":        RadialBasisKernel,
124 | 		}[val]
125 | 		if !ok {
126 | 			return nil, errors.New("unknown kernel: " + val)
127 | 		} else {
128 | 			return []KernelType{res}, nil
129 | 		}
130 | 	} else {
131 | 		return []KernelType{LinearKernel, PolynomialKernel, RadialBasisKernel}, nil
132 | 	}
133 | }
134 | 
135 | func envKernelParams(t KernelType) ([][]float64, error) {
136 | 	switch t {
137 | 	case LinearKernel:
138 | 		return [][]float64{{}}, nil
139 | 	case RadialBasisKernel:
140 | 		if val := os.Getenv(RBFParamEnvVar); val != "" {
141 | 			res, err := strconv.ParseFloat(val, 64)
142 | 			if err != nil {
143 | 				return nil, errors.New("invalid RBF param: " + val)
144 | 			}
145 | 			return [][]float64{{res}}, nil
146 | 		} else {
147 | 			return defaultRBFParams, nil
148 | 		}
149 | 	case PolynomialKernel:
150 | 		powers := defaultPolyPowers
151 | 		sums := defaultPolySums
152 | 		if val := os.Getenv(PolySumEnvVar); val != "" {
153 | 			sum, err := strconv.ParseFloat(val, 64)
154 | 			if err != nil {
155 | 				return nil, errors.New("invalid poly sum: " + val)
156 | 			}
157 | 			sums = []float64{sum}
158 | 		}
159 | 		if val := os.Getenv(PolyDegreeEnvVar); val != "" {
160 | 			degree, err := strconv.ParseFloat(val, 64)
161 | 			if err != nil {
162 | 				return nil, errors.New("invalid poly degree: " + val)
163 | 			}
164 | 			powers = []float64{degree}
165 | 		}
166 | 		res := make([][]float64, 0, len(powers)*len(sums))
167 | 		for _, power := range powers {
168 | 			for _, sum := range sums {
169 | 				res = append(res, []float64{sum, power})
170 | 			}
171 | 		}
172 | 		return res, nil
173 | 	default:
174 | 		panic("unknown kernel: " + strconv.Itoa(int(t)))
175 | 	}
176 | }
177 | 


--------------------------------------------------------------------------------
/neuralnet/train.go:
--------------------------------------------------------------------------------
  1 | package neuralnet
  2 | 
  3 | import (
  4 | 	"log"
  5 | 	"math/rand"
  6 | 
  7 | 	"github.com/unixpickle/whichlang/tokens"
  8 | )
  9 | 
 10 | const InitialIterationCount = 200
 11 | 
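// Train builds a DataSet from the samples, runs gradient
// descent once per candidate step size, and returns the
// network with the best cross-validation score, breaking
// ties by training score.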
 12 | func Train(data map[string][]tokens.Freqs) *Network {
 13 | 	ds := NewDataSet(data)
 14 | 
 15 | 	var best *Network
 16 | 	var bestCrossScore float64
 17 | 	var bestTrainScore float64
 18 | 
 19 | 	verbose := verboseFlag()
 20 | 
 21 | 	for _, stepSize := range stepSizes() {
 22 | 		if verbose {
 23 | 			log.Printf("trying step size %f", stepSize)
 24 | 		}
 25 | 
 26 | 		t := NewTrainer(ds, stepSize, verbose)
 27 | 		t.Train(maxIterations())
 28 | 
 29 | 		n := t.Network()
 30 | 		if n.containsNaN() {
 31 | 			if verbose {
 32 | 				log.Printf("got NaN for step size %f", stepSize)
 33 | 			}
 34 | 			continue
 35 | 		}
 36 | 		crossScore := ds.CrossScore(n)
 37 | 		trainScore := ds.TrainingScore(n)
 38 | 		if verbose {
 39 | 			log.Printf("stepSize=%f crossScore=%f trainScore=%f", stepSize,
 40 | 				crossScore, trainScore)
 41 | 		}
 42 | 		if (crossScore == bestCrossScore && trainScore >= bestTrainScore) ||
 43 | 			best == nil || (crossScore > bestCrossScore) {
 44 | 			bestCrossScore = crossScore
 45 | 			bestTrainScore = trainScore
 46 | 			best = n
 47 | 		}
 48 | 	}
 49 | 
 50 | 	return best
 51 | }
 52 | 
 53 | type Trainer struct {
 54 | 	n *Network
 55 | 	d *DataSet
 56 | 	g *gradientCalc
 57 | 
 58 | 	stepSize float64
 59 | 	verbose  bool
 60 | }
 61 | 
 62 | func NewTrainer(d *DataSet, stepSize float64, verbose bool) *Trainer {
 63 | 	hiddenCount := hiddenSize(len(d.TrainingSamples))
 64 | 	n := &Network{
 65 | 		Tokens:        d.Tokens(),
 66 | 		Langs:         d.Langs(),
 67 | 		HiddenWeights: make([][]float64, hiddenCount),
 68 | 		OutputWeights: make([][]float64, len(d.TrainingSamples)),
 69 | 		InputShift:    -d.MeanFrequency,
 70 | 		InputScale:    1 / d.FrequencyStddev,
 71 | 	}
 72 | 	for i := range n.OutputWeights {
 73 | 		n.OutputWeights[i] = make([]float64, hiddenCount+1)
 74 | 		for j := range n.OutputWeights[i] {
 75 | 			n.OutputWeights[i][j] = rand.Float64()*2 - 1
 76 | 		}
 77 | 	}
 78 | 	for i := range n.HiddenWeights {
 79 | 		n.HiddenWeights[i] = make([]float64, len(n.Tokens)+1)
 80 | 		for j := range n.HiddenWeights[i] {
 81 | 			n.HiddenWeights[i][j] = rand.Float64()*2 - 1
 82 | 		}
 83 | 	}
 84 | 	return &Trainer{
 85 | 		n:        n,
 86 | 		d:        d,
 87 | 		g:        newGradientCalc(n),
 88 | 		stepSize: stepSize,
 89 | 		verbose:  verbose,
 90 | 	}
 91 | }
 92 | 
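// Train runs at most maxIters iterations of stochastic
// gradient descent. After an initial batch, it repeatedly
// doubles the iteration count, stopping early (and restoring
// the previous network) once the cross-validation score
// stops improving or the weights become NaN.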
 93 | func (t *Trainer) Train(maxIters int) {
 94 | 	iters := InitialIterationCount
 95 | 	if iters > maxIters {
 96 | 		iters = maxIters
 97 | 	}
 98 | 	for i := 0; i < iters; i++ {
 99 | 		if verboseStepsFlag() {
100 | 			log.Printf("done %d iterations, cross=%f training=%f",
101 | 				i, t.d.CrossScore(t.n), t.d.TrainingScore(t.n))
102 | 		}
103 | 		t.runAllSamples()
104 | 	}
105 | 	if iters == maxIters {
106 | 		return
107 | 	}
108 | 
109 | 	if t.n.containsNaN() {
110 | 		return
111 | 	}
112 | 
113 | 	// Use cross-validation to find the best
114 | 	// number of iterations.
115 | 	crossScore := t.d.CrossScore(t.n)
116 | 	trainScore := t.d.TrainingScore(t.n)
117 | 	lastNet := t.n.Copy()
118 | 
119 | 	for {
120 | 		if t.verbose {
121 | 			log.Printf("current scores: cross=%f train=%f iters=%d",
122 | 				crossScore, trainScore, iters)
123 | 		}
124 | 
125 | 		nextAmount := iters
126 | 		if nextAmount+iters > maxIters {
127 | 			nextAmount = maxIters - iters
128 | 		}
129 | 		for i := 0; i < nextAmount; i++ {
130 | 			if verboseStepsFlag() {
131 | 				log.Printf("done %d iterations, cross=%f training=%f",
132 | 					i+iters, t.d.CrossScore(t.n), t.d.TrainingScore(t.n))
133 | 			}
134 | 			t.runAllSamples()
135 | 			if t.n.containsNaN() {
136 | 				break
137 | 			}
138 | 		}
139 | 		iters += nextAmount
140 | 
141 | 		if t.n.containsNaN() {
142 | 			t.n = lastNet
143 | 			break
144 | 		}
145 | 
146 | 		newCrossScore := t.d.CrossScore(t.n)
147 | 		newTrainScore := t.d.TrainingScore(t.n)
148 | 		if (newCrossScore == crossScore && newTrainScore == trainScore) ||
149 | 			newCrossScore < crossScore {
150 | 			t.n = lastNet
151 | 			return
152 | 		}
153 | 
154 | 		crossScore = newCrossScore
155 | 		trainScore = newTrainScore
156 | 
157 | 		if iters == maxIters {
158 | 			return
159 | 		}
160 | 		lastNet = t.n.Copy()
161 | 	}
162 | }
163 | 
164 | func (t *Trainer) Network() *Network {
165 | 	return t.n
166 | }
167 | 
168 | func (t *Trainer) runAllSamples() {
169 | 	var samples []struct {
170 | 		LangIdx int
171 | 		Sample  []float64
172 | 	}
173 | 
174 | 	for i, lang := range t.n.Langs {
175 | 		var sample struct {
176 | 			LangIdx int
177 | 			Sample  []float64
178 | 		}
179 | 		sample.LangIdx = i
180 | 
181 | 		trainingSamples := t.d.NormalTrainingSamples[lang]
182 | 		for _, s := range trainingSamples {
183 | 			sample.Sample = s
184 | 			samples = append(samples, sample)
185 | 		}
186 | 	}
187 | 
188 | 	perm := rand.Perm(len(samples))
189 | 	for _, i := range perm {
190 | 		t.descendSample(samples[i].Sample, samples[i].LangIdx)
191 | 	}
192 | }
193 | 
194 | // descendSample performs gradient descent to
195 | // reduce the output error for a given sample.
196 | func (t *Trainer) descendSample(inputs []float64, langIdx int) {
197 | 	t.g.Compute(inputs, langIdx)
198 | 
199 | 	for i, partials := range t.g.HiddenPartials {
200 | 		for j, partial := range partials {
201 | 			t.n.HiddenWeights[i][j] -= partial * t.stepSize
202 | 		}
203 | 	}
204 | 	for i, partials := range t.g.OutputPartials {
205 | 		for j, partial := range partials {
206 | 			t.n.OutputWeights[i][j] -= partial * t.stepSize
207 | 		}
208 | 	}
209 | }
210 | 


--------------------------------------------------------------------------------
/idtree/train.go:
--------------------------------------------------------------------------------
  1 | package idtree
  2 | 
  3 | import (
  4 | 	"math"
  5 | 	"runtime"
  6 | 	"sort"
  7 | 
  8 | 	"github.com/unixpickle/whichlang/tokens"
  9 | )
 10 | 
 11 | type splitInfo struct {
 12 | 	TokenIdx  int
 13 | 	Threshold float64
 14 | 	Entropy   float64
 15 | }
 16 | 
 17 | // Train returns a *Classifier which is the
 18 | // result of running ID3 on a set of training
 19 | // samples.
 20 | func Train(freqs map[string][]tokens.Freqs) *Classifier {
 21 | 	toks := allTokens(freqs)
 22 | 	samples := freqsToLinearSamples(toks, freqs)
 23 | 	return generateClassifier(toks, samples)
 24 | }
 25 | 
 26 | func allTokens(freqs map[string][]tokens.Freqs) []string {
 27 | 	words := make([]string, 0)
 28 | 	seenWords := map[string]bool{}
 29 | 	for _, freqsList := range freqs {
 30 | 		for _, freqs := range freqsList {
 31 | 			for word := range freqs {
 32 | 				if !seenWords[word] {
 33 | 					seenWords[word] = true
 34 | 					words = append(words, word)
 35 | 				}
 36 | 			}
 37 | 		}
 38 | 	}
 39 | 	return words
 40 | }
 41 | 
 42 | // generateClassifier generates a classifier
 43 | // for the given set of samples.
 44 | func generateClassifier(toks []string, s []linearSample) *Classifier {
 45 | 	tokIdx, thresh := bestDecision(s)
 46 | 	if tokIdx == -1 {
 47 | 		lang := languageMajority(s)
 48 | 		return &Classifier{
 49 | 			LeafClassification: &lang,
 50 | 		}
 51 | 	}
 52 | 	res := &Classifier{
 53 | 		Keyword:   toks[tokIdx],
 54 | 		Threshold: thresh,
 55 | 	}
 56 | 	f, t := splitData(s, tokIdx, thresh)
 57 | 	res.FalseBranch = generateClassifier(toks, f)
 58 | 	res.TrueBranch = generateClassifier(toks, t)
 59 | 	return res
 60 | }
 61 | 
 62 | func splitData(s []linearSample, tokIdx int, thresh float64) (f, t []linearSample) {
 63 | 	f = make([]linearSample, 0, len(s))
 64 | 	t = make([]linearSample, 0, len(s))
 65 | 
 66 | 	for _, sample := range s {
 67 | 		if sample.freqs[tokIdx] > thresh {
 68 | 			t = append(t, sample)
 69 | 		} else {
 70 | 			f = append(f, sample)
 71 | 		}
 72 | 	}
 73 | 
 74 | 	return
 75 | }
 76 | 
 77 | // bestDecision returns the token and threshold
 78 | // which split the samples optimally (by the
 79 | // criterion of entropy).
 80 | // If no split exists, this returns (-1, -1).
 81 | func bestDecision(s []linearSample) (tokIdx int, thresh float64) {
 82 | 	if len(s) == 0 {
 83 | 		return -1, -1
 84 | 	}
 85 | 
 86 | 	maxProcs := runtime.GOMAXPROCS(0)
 87 | 	tokenCount := len(s[0].freqs)
 88 | 
 89 | 	toksPerGo := tokenCount / maxProcs
 90 | 	splitChan := make(chan *splitInfo, maxProcs)
 91 | 	for i := 0; i < maxProcs; i++ {
 92 | 		tokCount := toksPerGo
 93 | 		tokStart := toksPerGo * i
 94 | 
 95 | 		// The last set might need to be slightly larger
 96 | 		// due to division truncation.
 97 | 		if i == maxProcs-1 {
 98 | 			tokCount = tokenCount - tokStart
 99 | 		}
100 | 
101 | 		go bestNodeSubset(tokStart, tokCount, s, splitChan)
102 | 	}
103 | 
104 | 	var best *splitInfo
105 | 	for i := 0; i < maxProcs; i++ {
106 | 		res := <-splitChan
107 | 		if res == nil {
108 | 			continue
109 | 		}
110 | 		if best == nil || res.Entropy < best.Entropy {
111 | 			best = res
112 | 		}
113 | 	}
114 | 
115 | 	if best == nil {
116 | 		return -1, -1
117 | 	}
118 | 
119 | 	return best.TokenIdx, best.Threshold
120 | }
121 | 
122 | func bestNodeSubset(startIdx, count int, s []linearSample, res chan<- *splitInfo) {
123 | 	bestThresh := -1.0
124 | 	var bestEntropy float64
125 | 	var bestIdx int
126 | 	for i := 0; i < count; i++ {
127 | 		idx := startIdx + i
128 | 		thresh, entropy := bestSplit(s, idx)
129 | 		if thresh < 0 {
130 | 			continue
131 | 		} else if bestThresh < 0 || entropy < bestEntropy {
132 | 			bestEntropy = entropy
133 | 			bestThresh = thresh
134 | 			bestIdx = idx
135 | 		}
136 | 	}
137 | 	if bestThresh == -1 {
138 | 		res <- nil
139 | 	} else {
140 | 		res <- &splitInfo{bestIdx, bestThresh, bestEntropy}
141 | 	}
142 | }
143 | 
144 | // bestSplit finds the ideal threshold for splitting
145 | // samples by a given token (specified by an index).
146 | // This returns the threshold and the resulting entropy.
147 | // The threshold will be -1 if no split is useful.
148 | func bestSplit(unsorted []linearSample, tokenIdx int) (thresh float64, entrop float64) {
149 | 	samples := make([]linearSample, len(unsorted))
150 | 	copy(samples, unsorted)
151 | 	sorter := &sampleSorter{samples, tokenIdx}
152 | 	sort.Sort(sorter)
153 | 
154 | 	lowerDistribution := map[string]int{}
155 | 	upperDistribution := map[string]int{}
156 | 
157 | 	for _, sample := range samples {
158 | 		upperDistribution[sample.lang]++
159 | 	}
160 | 
161 | 	if len(upperDistribution) == 1 {
162 | 		// Can't split homogeneous data effectively.
163 | 		return -1, -1
164 | 	}
165 | 
166 | 	thresh = -1
167 | 	entrop = -1
168 | 
169 | 	if len(samples) == 0 {
170 | 		return
171 | 	}
172 | 
173 | 	lastFreq := samples[0].freqs[tokenIdx]
174 | 	for i := 1; i < len(samples); i++ {
175 | 		upperDistribution[samples[i-1].lang]--
176 | 		lowerDistribution[samples[i-1].lang]++
177 | 
178 | 		freq := samples[i].freqs[tokenIdx]
179 | 		if freq == lastFreq {
180 | 			continue
181 | 		}
182 | 
183 | 		upperFrac := float64(len(samples)-i) / float64(len(samples))
184 | 		lowerFrac := float64(i) / float64(len(samples))
185 | 		disorder := upperFrac*distributionEntropy(upperDistribution) +
186 | 			lowerFrac*distributionEntropy(lowerDistribution)
187 | 		if disorder < entrop || thresh == -1 {
188 | 			entrop = disorder
189 | 			thresh = (lastFreq + freq) / 2
190 | 		}
191 | 
192 | 		lastFreq = freq
193 | 	}
194 | 
195 | 	return
196 | }
197 | 
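// distributionEntropy returns the Shannon entropy,
// -sum(p*ln(p)), of a distribution of language counts.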
198 | func distributionEntropy(dist map[string]int) float64 {
199 | 	var res float64
200 | 	var totalCount int
201 | 	for _, count := range dist {
202 | 		totalCount += count
203 | 	}
204 | 	for _, count := range dist {
205 | 		fraction := float64(count) / float64(totalCount)
206 | 		if fraction != 0 {
207 | 			res -= math.Log(fraction) * fraction
208 | 		}
209 | 	}
210 | 	return res
211 | }
212 | 


--------------------------------------------------------------------------------
/svm/trainer.go:
--------------------------------------------------------------------------------
  1 | package svm
  2 | 
  3 | import (
  4 | 	"log"
  5 | 	"math/rand"
  6 | 	"time"
  7 | 
  8 | 	"github.com/unixpickle/num-analysis/linalg"
  9 | 	"github.com/unixpickle/weakai/svm"
 10 | 	"github.com/unixpickle/whichlang/tokens"
 11 | )
 12 | 
 13 | const farAwayTimeout = time.Hour * 24 * 365
 14 | 
 15 | func Train(data map[string][]tokens.Freqs) *Classifier {
 16 | 	params, err := EnvTrainerParams()
 17 | 	if err != nil {
 18 | 		panic(err)
 19 | 	}
 20 | 	return TrainParams(data, params)
 21 | }
 22 | 
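// TrainParams trains one candidate Classifier per kernel in
// p, scores each candidate on a held-out cross-validation
// set, and returns the classifier with the highest score.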
 23 | func TrainParams(data map[string][]tokens.Freqs, p *TrainerParams) *Classifier {
 24 | 	crossFreqs, trainingFreqs := partitionSamples(data, p.CrossValidation)
 25 | 	tokens, samples := vectorizeSamples(trainingFreqs)
 26 | 
 27 | 	solver := svm.GradientDescentSolver{
 28 | 		Timeout:  farAwayTimeout,
 29 | 		Tradeoff: p.Tradeoff,
 30 | 	}
 31 | 
 32 | 	var bestClassifier *Classifier
 33 | 	var bestValidationScore float64
 34 | 
 35 | 	for _, kernel := range p.Kernels {
 36 | 		if p.Verbose {
 37 | 			log.Println("Trying kernel:", kernel)
 38 | 		}
 39 | 		solverKernel := cachedKernel(kernel)
 40 | 		classifier := &Classifier{
 41 | 			Keywords:    tokens,
 42 | 			Kernel:      kernel,
 43 | 			Classifiers: map[string]BinaryClassifier{},
 44 | 		}
 45 | 
 46 | 		usedSamples := map[int]linalg.Vector{}
 47 | 		for lang := range samples {
 48 | 			if p.Verbose {
 49 | 				log.Println("Training classifier for language:", lang)
 50 | 			}
 51 | 			problem := svmProblem(samples, lang, solverKernel)
 52 | 			solution := solver.Solve(problem)
 53 | 			binClass := BinaryClassifier{
 54 | 				SupportVectors: make([]int, len(solution.SupportVectors)),
 55 | 				Weights:        make([]float64, len(solution.Coefficients)),
 56 | 				Threshold:      -solution.Threshold,
 57 | 			}
 58 | 			copy(binClass.Weights, solution.Coefficients)
 59 | 			for i, v := range solution.SupportVectors {
 60 | 				// v.UserInfo will be turned into a support
 61 | 				// vector index by makeSampleVectorList().
 62 | 				binClass.SupportVectors[i] = v.UserInfo
 63 | 				usedSamples[v.UserInfo] = linalg.Vector(v.V)
 64 | 			}
 65 | 			classifier.Classifiers[lang] = binClass
 66 | 		}
 67 | 
 68 | 		makeSampleVectorList(classifier, usedSamples)
 69 | 
 70 | 		score := correctFraction(classifier, crossFreqs)
 71 | 		if p.Verbose {
 72 | 			trainingScore := correctFraction(classifier, trainingFreqs)
 73 | 			log.Printf("Results: cross=%f training=%f support=%d/%d", score,
 74 | 				trainingScore, len(classifier.SampleVectors), countSamples(samples))
 75 | 		}
 76 | 		if score > bestValidationScore || bestClassifier == nil {
  77 | 			bestClassifier, bestValidationScore = classifier, score
 78 | 		}
 79 | 	}
 80 | 
 81 | 	return bestClassifier
 82 | }
 83 | 
 84 | func partitionSamples(data map[string][]tokens.Freqs, crossFrac float64) (cross,
 85 | 	training map[string][]tokens.Freqs) {
 86 | 
 87 | 	cross = map[string][]tokens.Freqs{}
 88 | 	training = map[string][]tokens.Freqs{}
 89 | 
 90 | 	for lang, samples := range data {
 91 | 		p := rand.Perm(len(samples))
 92 | 		newSamples := make([]tokens.Freqs, len(samples))
 93 | 		for i, x := range p {
 94 | 			newSamples[i] = samples[x]
 95 | 		}
 96 | 		crossCount := int(crossFrac * float64(len(samples)))
 97 | 		cross[lang] = newSamples[:crossCount]
 98 | 		training[lang] = newSamples[crossCount:]
 99 | 	}
100 | 
101 | 	return
102 | }
103 | 
104 | func vectorizeSamples(data map[string][]tokens.Freqs) ([]string, map[string][]svm.Sample) {
105 | 	seenToks := map[string]bool{}
106 | 	for _, samples := range data {
107 | 		for _, sample := range samples {
108 | 			for tok := range sample {
109 | 				seenToks[tok] = true
110 | 			}
111 | 		}
112 | 	}
113 | 	toks := make([]string, 0, len(seenToks))
114 | 	for tok := range seenToks {
115 | 		toks = append(toks, tok)
116 | 	}
117 | 
118 | 	sampleMap := map[string][]svm.Sample{}
119 | 	sampleID := 1
120 | 	for lang, samples := range data {
121 | 		vecSamples := make([]svm.Sample, 0, len(samples))
122 | 		for _, sample := range samples {
123 | 			vec := make([]float64, len(toks))
124 | 			for i, tok := range toks {
125 | 				vec[i] = sample[tok]
126 | 			}
127 | 			svmSample := svm.Sample{
128 | 				V:        vec,
129 | 				UserInfo: sampleID,
130 | 			}
131 | 			sampleID++
132 | 			vecSamples = append(vecSamples, svmSample)
133 | 		}
134 | 		sampleMap[lang] = vecSamples
135 | 	}
136 | 
137 | 	return toks, sampleMap
138 | }
139 | 
140 | func countSamples(s map[string][]svm.Sample) int {
141 | 	var count int
142 | 	for _, samples := range s {
143 | 		count += len(samples)
144 | 	}
145 | 	return count
146 | }
147 | 
148 | func cachedKernel(k *Kernel) svm.Kernel {
149 | 	return svm.CachedKernel(func(s1, s2 svm.Sample) float64 {
150 | 		return k.Product(linalg.Vector(s1.V), linalg.Vector(s2.V))
151 | 	})
152 | }
153 | 
154 | func svmProblem(data map[string][]svm.Sample, posLang string, k svm.Kernel) *svm.Problem {
155 | 	var positives, negatives []svm.Sample
156 | 	for lang, samples := range data {
157 | 		if lang == posLang {
158 | 			positives = append(positives, samples...)
159 | 		} else {
160 | 			negatives = append(negatives, samples...)
161 | 		}
162 | 	}
163 | 	return &svm.Problem{
164 | 		Positives: positives,
165 | 		Negatives: negatives,
166 | 		Kernel:    k,
167 | 	}
168 | }
169 | 
170 | func correctFraction(c *Classifier, data map[string][]tokens.Freqs) float64 {
171 | 	var correct, total int
172 | 	for lang, samples := range data {
173 | 		for _, sample := range samples {
174 | 			total++
175 | 			if c.Classify(sample) == lang {
176 | 				correct++
177 | 			}
178 | 		}
179 | 	}
180 | 	return float64(correct) / float64(total)
181 | }
182 | 
183 | func makeSampleVectorList(c *Classifier, used map[int]linalg.Vector) {
184 | 	userInfoToVecIdx := map[int]int{}
185 | 
186 | 	for userInfo, sample := range used {
187 | 		userInfoToVecIdx[userInfo] = len(c.SampleVectors)
188 | 		c.SampleVectors = append(c.SampleVectors, sample)
189 | 	}
190 | 
191 | 	for _, binClass := range c.Classifiers {
192 | 		for i, userInfo := range binClass.SupportVectors {
193 | 			binClass.SupportVectors[i] = userInfoToVecIdx[userInfo]
194 | 		}
195 | 	}
196 | }
197 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # whichlang
  2 | 
  3 | This is a suite of Machine Learning tools for identifying the language in which a piece of code is written. It could be used in text editors, by code hosting websites, and much more.
  4 | 
  5 | A seasoned programmer could quickly tell you that this program is written in C:
  6 | 
  7 | ```c
  8 | #include <stdio.h>
  9 | 
 10 | int main(int argc, const char ** argv) {
 11 |   printf("Hello, world!");
 12 | }
 13 | ```
 14 | 
 15 | The goal of `whichlang` is to teach a program to do the same. By showing a Machine Learning algorithm a ton of code, you can get it to *learn* to identify programming languages itself.
 16 | 
 17 | # Usage
 18 | 
 19 | There are four steps to using whichlang:
 20 | 
 21 |  * Configure Go and download whichlang.
 22 |  * Fetch code samples from Github or some other source.
 23 |  * Train a classifier with the code samples.
 24 |  * Use the whichlang API or server with the classifier you trained.
 25 | 
 26 | ## Configuring Go and whichlang
 27 | 
 28 | First, follow the instructions on [this page](https://golang.org/doc/install) to set up Go. Once Go is set up and you have a `GOPATH` configured, run this set of commands:
 29 | 
 30 | ```
 31 | $ go get github.com/unixpickle/whichlang
 32 | $ cd $GOPATH/src/github.com/unixpickle/whichlang
 33 | ```
 34 | 
 35 | Now you have downloaded `whichlang` and are sitting in its root source folder.
 36 | 
 37 | ## Fetching samples
 38 | 
 39 | To fetch samples from Github, you must have a Github account (a second account can help, since Github's API rate limits apply per account). You should decide how many samples you want for each programming language. I have found that 180 is more than enough.
 40 | 
 41 | You can fetch samples and save them to a directory as follows:
 42 | 
 43 | ```
 44 | $ mkdir /path/to/samples
 45 | $ go run cmd/fetchlang/*.go /path/to/samples 180
 46 | ```
 47 | 
 48 | In the above example, I specified 180 samples per language. This will prompt you for your Github credentials (to get around strict API rate limits). If you specify a large number of samples (where 180 counts as a large number), you may hit Github's API rate limits several times during the fetching process. If this occurs, you will want to delete the partially-downloaded source directories (they will be subdirectories of your sample directory, and will contain fewer than 180 samples), then wait an hour before re-running `fetchlang`. The `fetchlang` sub-command will automatically skip any source directories that are already present, making it relatively easy to resume paused or rate-limited downloads.
 49 | 
 50 | ## Training a classifier
 51 | 
 52 | With whichlang, you can train a number of different kinds of classifiers on your data. Currently, you can use the following classifiers:
 53 | 
 54 |  * [ID3](https://en.wikipedia.org/wiki/ID3)
 55 |  * [K-nearest neighbors](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
 56 |  * [Artificial Neural Networks](https://en.wikipedia.org/wiki/Artificial_neural_network)
 57 |  * [Support Vector Machines](https://en.wikipedia.org/wiki/Support_vector_machine)
 58 | 
 59 | Out of these algorithms, I have found that Support Vector Machines are the simplest to train and work very well. Artificial Neural Networks are a close second, but they have more hyper-parameters and are thus harder to tune well. In this document, I will describe how to train both of these classifiers, leaving out ID3 and K-nearest neighbors.
 60 | 
 61 | ### Choosing the "ubiquity"
 62 | 
 63 | For any classifier you use, you must choose a "ubiquity" value. Since whichlang works by extracting keywords from source files, it is important to distinguish potentially meaningful keywords from file-specific keywords like variable names or embedded strings. To do this, keywords which appear in fewer than `N` files are ignored during training and classification, where `N` is the "ubiquity". I have found that a ubiquity of 10-20 works well when you have roughly 100 source files. A short programmatic sketch of this pruning follows below.
 64 | 
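If you want to see what pruning at a given ubiquity does to your data, here is a minimal sketch using this repository's `tokens` package (the `trainer` sub-command normally performs this pruning for you; the sample path is a placeholder):

```go
package main

import (
	"fmt"

	"github.com/unixpickle/whichlang/tokens"
)

func main() {
	// Read per-file token counts from the language
	// sub-directories of the sample directory.
	counts, err := tokens.ReadSampleCounts("/path/to/samples")
	if err != nil {
		panic(err)
	}
	fmt.Println("tokens before pruning:", counts.NumTokens())

	// Drop tokens which appear in 15 documents or fewer,
	// replacing them with a catch-all "" token.
	counts.Prune(15)
	fmt.Println("tokens after pruning:", counts.NumTokens())

	// Convert the pruned counts into the frequency maps
	// consumed by the training algorithms.
	freqs := counts.SampleFreqs()
	fmt.Println("languages:", len(freqs))
}
```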
 65 | ### Support Vector Machines
 66 | 
 67 | The easiest way to train a Support Vector Machine is to allow whichlang to select all the hyper-parameters for you. Note, however, that this option is *very* slow, so you may want to keep reading.
 68 | 
 69 | ```
 70 | $ go run cmd/trainer/*.go svm 15 /path/to/samples /path/to/classifier.json
 71 | ```
 72 | 
 73 | In the above command, I specified a ubiquity of 15 files. This will train an SVM on the given sample directory, outputting a classifier to `/path/to/classifier.json`. As this command runs, it will go through many different possible SVM configurations, choosing the one which performs best on new samples (as measured via cross-validation). Since this command has to try many possible configurations, it will take a long time to run (perhaps hours or even days). I have already gone through the trouble of finding good parameters, and I will now share my results.
 74 | 
 75 | I have found that a linear SVM works fairly well for programming language classification. In particular, I've gotten a linear SVM to reach a 93% success rate on new samples, and most of those mistakes were reasonable (e.g. mistaking C for C++, or mistaking Ruby for CoffeeScript). To train a linear SVM, you can set the `SVM_KERNEL` environment variable before running the `trainer` sub-command:
 76 | 
 77 | ```
 78 | $ export SVM_KERNEL=linear
 79 | ```
 80 | 
 81 | If you want verbose output during training, you can specify another environment variable:
 82 | 
 83 | ```
 84 | $ export SVM_VERBOSE=1
 85 | ```
 86 | 
 87 | Once you have trained a linear SVM, you can perform a special compression step which will make the classifier faster and smaller. This is a technique which only works for linear SVMs! Run the following command:
 88 | 
 89 | ```
 90 | $ go run cmd/svm-shrink/*.go /path/to/classifier.json /path/to/optimized.json
 91 | ```
 92 | 
 93 | This will create a classifier file at `/path/to/optimized.json` which is the optimized version of `/path/to/classifier.json`. **Remember, this only works for linear SVMs.**
 94 | 
 95 | For other SVM environment variables, you can check out [this list](https://godoc.org/github.com/unixpickle/whichlang/svm#pkg-constants).
 96 | 
 97 | ### Artificial Neural Networks
 98 | 
 99 | While whichlang does allow you to train ANNs without specifying any hyper-parameters (via grid search), doing so will take a tremendous amount of time. It is highly recommended that you manually specify the parameters for your neural network. I will give one example of training an ANN, but it is up to you to tweak these parameters:
100 | 
101 | ```
102 | $ export NEURALNET_VERBOSE=1
103 | $ export NEURALNET_VERBOSE_STEPS=1
104 | $ export NEURALNET_STEP_SIZE=0.01
105 | $ export NEURALNET_MAX_ITERS=100
106 | $ export NEURALNET_HIDDEN_SIZE=150
107 | $ go run cmd/trainer/*.go neuralnet 15 /path/to/samples /path/to/classifier.json
108 | ```
109 | 
 110 | For more ANN environment variables, you can check out [this list](https://godoc.org/github.com/unixpickle/whichlang/neuralnet#pkg-variables).
111 | 
112 | ## Using a classifier
113 | 
 114 | Using a classifier is as simple as loading a file. You can check out the [classify command](https://github.com/unixpickle/whichlang/blob/master/cmd/classify/main.go) for a very simple (15-line) example.
115 | 
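Here is a comparable minimal sketch that loads a neural-network classifier and classifies a single file using the packages in this repository (both paths are placeholders):

```go
package main

import (
	"fmt"
	"io/ioutil"

	"github.com/unixpickle/whichlang/neuralnet"
	"github.com/unixpickle/whichlang/tokens"
)

func main() {
	// Load a classifier that was trained with the
	// neuralnet algorithm.
	data, err := ioutil.ReadFile("/path/to/classifier.json")
	if err != nil {
		panic(err)
	}
	network, err := neuralnet.DecodeNetwork(data)
	if err != nil {
		panic(err)
	}

	// Tokenize a source file and classify it.
	source, err := ioutil.ReadFile("/path/to/source/file")
	if err != nil {
		panic(err)
	}
	freqs := tokens.CountTokens(string(source)).Freqs()
	fmt.Println("Classification:", network.Classify(freqs))
}
```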


--------------------------------------------------------------------------------