├── README.md
├── cmd
│   ├── classify
│   │   └── main.go
│   ├── fetchlang-pastie
│   │   └── main.go
│   ├── fetchlang
│   │   ├── file_search.go
│   │   ├── github.go
│   │   ├── main.go
│   │   └── repo_search.go
│   ├── rater
│   │   ├── main.go
│   │   ├── rate.go
│   │   └── types.go
│   ├── server
│   │   ├── assets
│   │   │   ├── index.html
│   │   │   ├── main.js
│   │   │   └── style.css
│   │   └── main.go
│   ├── subsamples
│   │   └── main.go
│   ├── svm-shrink
│   │   └── main.go
│   └── trainer
│       └── main.go
├── gaussbayes
│   ├── classifier.go
│   └── train.go
├── idtree
│   ├── classifier.go
│   ├── samples.go
│   └── train.go
├── knn
│   ├── classifier.go
│   └── trainer.go
├── main.go
├── neuralnet
│   ├── classifier.go
│   ├── data_set.go
│   ├── env.go
│   ├── gradients.go
│   └── train.go
├── svm
│   ├── classifier.go
│   ├── kernel.go
│   ├── trainer.go
│   └── trainer_params.go
└── tokens
    ├── counts.go
    ├── counts_test.go
    ├── freqs.go
    ├── sample_counts.go
    └── sample_counts_test.go

/README.md:
--------------------------------------------------------------------------------
  1 | # whichlang
  2 | 
  3 | This is a suite of Machine Learning tools for identifying the language in which a piece of code is written. It could potentially be used for text editors, code hosting websites, and much more.
  4 | 
  5 | A seasoned programmer could quickly tell you that this program is written in C:
  6 | 
  7 | ```c
  8 | #include <stdio.h>
  9 | 
 10 | int main(int argc, const char ** argv) {
 11 |     printf("Hello, world!");
 12 | }
 13 | ```
 14 | 
 15 | The goal of `whichlang` is to teach a program to do the same. If you show a Machine Learning algorithm a ton of code, it can *learn* to identify programming languages itself.
 16 | 
 17 | # Usage
 18 | 
 19 | There are four steps to using whichlang:
 20 | 
 21 | * Configure Go and download whichlang.
 22 | * Fetch code samples from Github or some other source.
 23 | * Train a classifier with the code samples.
 24 | * Use the whichlang API or server with the classifier you trained.
 25 | 
 26 | ## Configuring Go and whichlang
 27 | 
 28 | First, follow the instructions on [this page](https://golang.org/doc/install) to set up Go.
Once Go is set up and you have a `GOPATH` configured, run this set of commands:
 29 | 
 30 | ```
 31 | $ go get github.com/unixpickle/whichlang
 32 | $ cd $GOPATH/src/github.com/unixpickle/whichlang
 33 | ```
 34 | 
 35 | Now you have downloaded `whichlang` and are sitting in its root source folder.
 36 | 
 37 | ## Fetching samples
 38 | 
 39 | To fetch samples from Github, you must have a Github account (having more than one Github account may also be helpful). You should decide how many samples you want for each programming language. I have found that 180 is more than enough.
 40 | 
 41 | You can fetch samples and save them to a directory as follows:
 42 | 
 43 | ```
 44 | $ mkdir /path/to/samples
 45 | $ go run cmd/fetchlang/*.go /path/to/samples 180
 46 | ```
 47 | 
 48 | In the above example, I specified 180 samples per language. This command will prompt you for your Github credentials (to get around strict API rate limits). If you specify a large number of samples (where 180 counts as a large number), you may hit Github's API rate limits several times during the fetching process. If this occurs, delete the partially-downloaded source directories (they will be subdirectories of your sample directory, and will contain fewer than 180 samples), then wait an hour before re-running `fetchlang`. The `fetchlang` sub-command automatically skips any source directories that are already present, making it relatively easy to resume paused or rate-limited downloads.
 49 | 
 50 | ## Training a classifier
 51 | 
 52 | With whichlang, you can train a number of different kinds of classifiers on your data.
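Whichever algorithm you pick, a trained classifier is used the same way: decode it from its JSON file, then ask it to classify a map of token frequencies (you can see this pattern in `cmd/classify` and `cmd/rater`). The sketch below is illustrative only — the `Freqs` and `Classifier` names mirror the real `tokens` and `whichlang` packages, but the toy keyword-marker model is made up and is not how the real classifiers work:

```go
package main

import "fmt"

// Freqs maps a token to its relative frequency in a source
// file, mirroring the tokens.Freqs type used by whichlang.
type Freqs map[string]float64

// Classifier is the shape of interface every trained model
// exposes: given token frequencies, return a language name.
type Classifier interface {
	Classify(Freqs) string
}

// keywordClassifier is a toy stand-in for a trained model:
// it picks the language whose marker token is most frequent.
type keywordClassifier struct {
	markers map[string]string // token -> language
}

func (k *keywordClassifier) Classify(f Freqs) string {
	bestLang, bestFreq := "unknown", 0.0
	for token, lang := range k.markers {
		if f[token] > bestFreq {
			bestLang, bestFreq = lang, f[token]
		}
	}
	return bestLang
}

func main() {
	var c Classifier = &keywordClassifier{markers: map[string]string{
		"func": "Go",
		"def":  "Python",
	}}
	sample := Freqs{"func": 0.05, "package": 0.02}
	fmt.Println(c.Classify(sample))
}
```

Because every algorithm plugs into the same interface, commands like `classify` and `rater` can take the algorithm name as their first argument and stay otherwise algorithm-agnostic.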
Currently, you can use the following classifiers:
 53 | 
 54 | * [ID3](https://en.wikipedia.org/wiki/ID3_algorithm)
 55 | * [K-nearest neighbors](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
 56 | * [Artificial Neural Networks](https://en.wikipedia.org/wiki/Artificial_neural_network)
 57 | * [Support Vector Machines](https://en.wikipedia.org/wiki/Support_vector_machine)
 58 | 
 59 | Out of these algorithms, I have found that Support Vector Machines are the simplest to train and work very well. Artificial Neural Networks are a close second, but they have more hyper-parameters and are thus harder to tune well. In this document, I will describe how to train both of these classifiers, leaving out ID3 and K-nearest neighbors.
 60 | 
 61 | ### Choosing the "ubiquity"
 62 | 
 63 | For any classifier you use, you must choose a "ubiquity" value. Since whichlang works by extracting keywords from source files, it is important to distinguish potentially important keywords from file-specific keywords like variable names or embedded strings. To do this, keywords which appear in fewer than `N` files are ignored during training and classification, where `N` is the "ubiquity". I have found that a ubiquity of 10-20 works well when you have roughly 100 source files.
 64 | 
 65 | ### Support Vector Machines
 66 | 
 67 | The easiest way to train a Support Vector Machine is to allow whichlang to select all the hyper-parameters for you. Note, however, that this option is *very* slow, so you may want to keep reading.
 68 | 
 69 | ```
 70 | $ go run cmd/trainer/*.go svm 15 /path/to/samples /path/to/classifier.json
 71 | ```
 72 | 
 73 | In the above command, I specified a ubiquity of 15 files. This will train an SVM on the given sample directory, outputting a classifier to `/path/to/classifier.json`. As this command runs, it will go through many different possible SVM configurations, choosing the one which performs best on new samples (as measured via cross-validation).
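The configuration search just described boils down to a simple loop: train a candidate, score it on held-out samples, keep the best. Here is a minimal sketch of that pattern — `Config` and the scoring function are hypothetical stand-ins, not whichlang's actual trainer types:

```go
package main

import "fmt"

// Config is a hypothetical SVM hyper-parameter setting.
type Config struct {
	C     float64 // soft-margin tradeoff
	Gamma float64 // kernel width
}

// pickBest tries every configuration and keeps the one with the
// highest validation score. This is the same shape of loop the
// trainer runs; its real scoring trains an SVM and measures
// accuracy on samples held out via cross-validation.
func pickBest(configs []Config, score func(Config) float64) (Config, float64) {
	best, bestScore := configs[0], score(configs[0])
	for _, c := range configs[1:] {
		if s := score(c); s > bestScore {
			best, bestScore = c, s
		}
	}
	return best, bestScore
}

func main() {
	configs := []Config{{1, 0.1}, {10, 0.1}, {10, 1}}
	// Stand-in scorer for illustration; a real one would be the
	// expensive train-and-validate step, which is why the search
	// takes so long.
	score := func(c Config) float64 { return 1 / (1 + c.C*c.Gamma) }
	best, s := pickBest(configs, score)
	fmt.Println(best, s)
}
```

The loop itself is cheap; the cost is entirely in the scoring function, which is why narrowing the candidate set (as described next) matters so much.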
Since this command has to try many possible configurations, it will take a long time to run (perhaps hours or even days). I have already gone through the trouble of finding good parameters, and I will now share my results.
 74 | 
 75 | I have found that a linear SVM works fairly well for programming language classification. In particular, I've gotten a linear SVM to reach a 93% success rate on new samples, and most of those mistakes were reasonable (e.g. mistaking C for C++, or mistaking Ruby for CoffeeScript). To train a linear SVM, you can set the `SVM_KERNEL` environment variable before running the `trainer` sub-command:
 76 | 
 77 | ```
 78 | $ export SVM_KERNEL=linear
 79 | ```
 80 | 
 81 | If you want verbose output during training, you can specify another environment variable:
 82 | 
 83 | ```
 84 | $ export SVM_VERBOSE=1
 85 | ```
 86 | 
 87 | Once you have trained a linear SVM, you can perform a special compression step which will make the classifier faster and smaller. This technique only works for linear SVMs! Run the following command:
 88 | 
 89 | ```
 90 | $ go run cmd/svm-shrink/*.go /path/to/classifier.json /path/to/optimized.json
 91 | ```
 92 | 
 93 | This will create a classifier file at `/path/to/optimized.json` which is the optimized version of `/path/to/classifier.json`. **Remember, this only works for linear SVMs.**
 94 | 
 95 | For other SVM environment variables, check out [this list](https://godoc.org/github.com/unixpickle/whichlang/svm#pkg-constants).
 96 | 
 97 | ### Artificial Neural Networks
 98 | 
 99 | While whichlang does allow you to train ANNs without specifying any hyper-parameters (via grid search), doing so will take a tremendous amount of time. It is highly recommended that you manually specify the parameters for your neural network.
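The most important ANN parameter is the hidden layer size (`NEURALNET_HIDDEN_SIZE` in the example that follows). To make that concrete, here is a rough sketch of what a single-hidden-layer network computes at classification time: the input vector is multiplied into the hidden layer, squashed by an activation function, then multiplied into one output per language. This is an illustration with made-up weights and no bias terms, not the `neuralnet` package's actual code:

```go
package main

import (
	"fmt"
	"math"
)

// sigmoid squashes a value into (0, 1).
func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// forward runs one input vector through a network with a single
// hidden layer. hiddenW is [hiddenSize][inputSize] and outW is
// [numClasses][hiddenSize]; the returned index is the class whose
// output unit is largest.
func forward(input []float64, hiddenW, outW [][]float64) int {
	hidden := make([]float64, len(hiddenW))
	for i, row := range hiddenW {
		var sum float64
		for j, w := range row {
			sum += w * input[j]
		}
		hidden[i] = sigmoid(sum)
	}
	bestIdx, bestVal := 0, math.Inf(-1)
	for i, row := range outW {
		var sum float64
		for j, w := range row {
			sum += w * hidden[j]
		}
		if sum > bestVal {
			bestIdx, bestVal = i, sum
		}
	}
	return bestIdx
}

func main() {
	// Two inputs, a hidden layer of three units, two output classes.
	hiddenW := [][]float64{{1, -1}, {0.5, 0.5}, {-1, 1}}
	outW := [][]float64{{1, 0, 0}, {0, 0, 1}}
	fmt.Println(forward([]float64{1, 0}, hiddenW, outW))
}
```

A larger hidden layer means more rows in `hiddenW`: more capacity to separate languages, but more weights to fit and a slower, more finicky training run.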
I will give one example of training an ANN, but it is up to you to tweak these parameters:
100 | 
101 | ```
102 | $ export NEURALNET_VERBOSE=1
103 | $ export NEURALNET_VERBOSE_STEPS=1
104 | $ export NEURALNET_STEP_SIZE=0.01
105 | $ export NEURALNET_MAX_ITERS=100
106 | $ export NEURALNET_HIDDEN_SIZE=150
107 | $ go run cmd/trainer/*.go neuralnet 15 /path/to/samples /path/to/classifier.json
108 | ```
109 | 
110 | For more ANN environment variables, check out [this list](https://godoc.org/github.com/unixpickle/whichlang/neuralnet#pkg-variables).
111 | 
112 | ## Using a classifier
113 | 
114 | Using a classifier is as simple as loading in a file. You can check out the [classify command](https://github.com/unixpickle/whichlang/blob/master/cmd/classify/main.go) for a very simple (15-line) example.
115 | 
--------------------------------------------------------------------------------
/cmd/classify/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"io/ioutil"
 6 | 	"os"
 7 | 
 8 | 	"github.com/unixpickle/whichlang"
 9 | 	"github.com/unixpickle/whichlang/tokens"
10 | )
11 | 
12 | func main() {
13 | 	if len(os.Args) != 4 {
14 | 		fmt.Fprintln(os.Stderr, "Usage: classify <algorithm> <classifier-file> <file>")
15 | 		os.Exit(1)
16 | 	}
17 | 
18 | 	decoder := whichlang.Decoders[os.Args[1]]
19 | 	if decoder == nil {
20 | 		fmt.Fprintln(os.Stderr, "Unknown algorithm:", os.Args[1])
21 | 		os.Exit(1)
22 | 	}
23 | 
24 | 	classifierData, err := ioutil.ReadFile(os.Args[2])
25 | 	if err != nil {
26 | 		fmt.Fprintln(os.Stderr, err)
27 | 		os.Exit(1)
28 | 	}
29 | 
30 | 	classifier, err := decoder(classifierData)
31 | 	if err != nil {
32 | 		fmt.Fprintln(os.Stderr, "Failed to decode classifier:", err)
33 | 		os.Exit(1)
34 | 	}
35 | 
36 | 	contents, err := ioutil.ReadFile(os.Args[3])
37 | 	if err != nil {
38 | 		fmt.Fprintln(os.Stderr, err)
39 | 		os.Exit(1)
40 | 	}
41 | 
42 | 	counts := tokens.CountTokens(string(contents))
43 | 	freqs := counts.Freqs()
44 | 	language :=
	classifier.Classify(freqs)
45 | 	fmt.Println("Classification:", language)
46 | }
47 | 
--------------------------------------------------------------------------------
/cmd/fetchlang-pastie/main.go:
--------------------------------------------------------------------------------
  1 | package main
  2 | 
  3 | import (
  4 | 	"errors"
  5 | 	"fmt"
  6 | 	"io/ioutil"
  7 | 	"net/http"
  8 | 	"os"
  9 | 	"path/filepath"
 10 | 	"regexp"
 11 | 	"strconv"
 12 | 	"strings"
 13 | 	"sync"
 14 | 
 15 | 	"github.com/yhat/scrape"
 16 | 	"golang.org/x/net/html"
 17 | 	"golang.org/x/net/html/atom"
 18 | )
 19 | 
 20 | const RoutineCount = 8
 21 | 
 22 | func main() {
 23 | 	if len(os.Args) != 4 {
 24 | 		fmt.Fprintln(os.Stderr, "Usage: fetchlang-pastie <start-index> <end-index> <output-dir>")
 25 | 		os.Exit(1)
 26 | 	}
 27 | 	startIdx, err := strconv.Atoi(os.Args[1])
 28 | 	if err != nil {
 29 | 		fmt.Fprintln(os.Stderr, err)
 30 | 		os.Exit(1)
 31 | 	}
 32 | 	endIdx, err := strconv.Atoi(os.Args[2])
 33 | 	if err != nil {
 34 | 		fmt.Fprintln(os.Stderr, err)
 35 | 		os.Exit(1)
 36 | 	}
 37 | 	pasteIndices := make(chan int)
 38 | 	go func() {
 39 | 		for i := startIdx; i <= endIdx; i++ {
 40 | 			pasteIndices <- i
 41 | 		}
 42 | 		close(pasteIndices)
 43 | 	}()
 44 | 	outDir := os.Args[3]
 45 | 	if err := ensureDirectoryPresent(outDir); err != nil {
 46 | 		fmt.Fprintln(os.Stderr, err)
 47 | 		os.Exit(1)
 48 | 	}
 49 | 	fetchPastes(pasteIndices, outDir)
 50 | }
 51 | 
 52 | func fetchPastes(indices <-chan int, outDir string) {
 53 | 	var wg sync.WaitGroup
 54 | 	for i := 0; i < RoutineCount; i++ {
 55 | 		wg.Add(1)
 56 | 		go func() {
 57 | 			defer wg.Done()
 58 | 			for index := range indices {
 59 | 				if err := fetchPaste(index, outDir); err != nil {
 60 | 					fmt.Fprintln(os.Stderr, "error for paste", index, err)
 61 | 				} else {
 62 | 					fmt.Println("succeeded for paste", index)
 63 | 				}
 64 | 			}
 65 | 		}()
 66 | 	}
 67 | 	wg.Wait()
 68 | }
 69 | 
 70 | func fetchPaste(index int, outDir string) error {
 71 | 	code, lang, err := fetchPasteCode(index)
 72 | 	if err != nil {
 73 | 		return err
 74 | 	}
 75 | 	codeDir := filepath.Join(outDir, lang)
 76 | 	ensureDirectoryPresent(codeDir)
 77 | 	fileName := strconv.Itoa(index) + ".txt"
 78 | 	return ioutil.WriteFile(filepath.Join(codeDir, fileName), []byte(code), 0755)
 79 | }
 80 | 
 81 | func fetchPasteCode(index int) (contents, language string, err error) {
 82 | 	response, err := http.Get("http://pastie.org/pastes/" + strconv.Itoa(index))
 83 | 	if err != nil {
 84 | 		return
 85 | 	}
 86 | 	body, err := ioutil.ReadAll(response.Body)
 87 | 	response.Body.Close()
 88 | 	if err != nil {
 89 | 		return
 90 | 	}
 91 | 	pageData := string(body)
 92 | 	exp := regexp.MustCompile("\\\n(.*?)\n ")
 93 | 	match := exp.FindStringSubmatch(pageData)
 94 | 	if match == nil {
 95 | 		ioutil.WriteFile("/Users/alex/Desktop/foo.html", []byte(pageData), 0755)
 96 | 		return "", "", errors.New("cannot locate language")
 97 | 	}
 98 | 	language = match[1]
 99 | 	language = strings.Replace(language, "/", ":", -1)
100 | 
101 | 	response, err = http.Get("http://pastie.org/pastes/" + strconv.Itoa(index) + "/text")
102 | 	if err != nil {
103 | 		return
104 | 	}
105 | 	root, err := html.Parse(response.Body)
106 | 	response.Body.Close()
107 | 	if err != nil {
108 | 		return
109 | 	}
110 | 	codeBlock, ok := scrape.Find(root, scrape.ByTag(atom.Pre))
111 | 	if !ok {
112 | 		return "", "", errors.New("no <pre> tag")
113 | 	}
114 | 	contents = codeBlockText(codeBlock)
115 | 	return
116 | }
117 | 
118 | func codeBlockText(n *html.Node) string {
119 | 	if n.DataAtom == atom.Br {
120 | 		return "\n"
121 | 	}
122 | 	if n.Type == html.TextNode {
123 | 		return n.Data
124 | 	}
125 | 
126 | 	var res string
127 | 	child := n.FirstChild
128 | 	for child != nil {
129 | 		res += codeBlockText(child)
130 | 		child = child.NextSibling
131 | 	}
132 | 	return res
133 | }
134 | 
135 | func ensureDirectoryPresent(dirPath string) error {
136 | 	if _, err := os.Stat(dirPath); err != nil {
137 | 		if err := os.Mkdir(dirPath, 0755); err != nil {
138 | 			return err
139 | 		}
140 | 	}
141 | 	return nil
142 | }
143 | 


--------------------------------------------------------------------------------
/cmd/fetchlang/file_search.go:
--------------------------------------------------------------------------------
  1 | package main
  2 | 
  3 | import (
  4 | 	"encoding/base64"
  5 | 	"encoding/json"
  6 | 	"errors"
  7 | 	"net/url"
  8 | 	"path"
  9 | 	"strings"
 10 | )
 11 | 
 12 | var (
 13 | 	ErrNoResults   = errors.New("no results")
 14 | 	ErrMaxRequests = errors.New("too many requests")
 15 | )
 16 | 
 17 | // A FileSearch defines parameters for
 18 | // searching a repository for files.
 19 | type FileSearch struct {
 20 | 	// Repository is the repository name,
 21 | 	// formatted as "user/repo".
 22 | 	Repository string
 23 | 
 24 | 	MinFileSize int
 25 | 	MaxFileSize int
 26 | 	Extensions  []string
 27 | 
 28 | 	// MaxRequests is the maximum number of
 29 | 	// API requests to be performed by the
 30 | 	// search before giving up.
 31 | 	MaxRequests int
 32 | }
 33 | 
 34 | // SearchFile runs a FileSearch.
 35 | // It returns ErrMaxRequests if more than
 36 | // s.MaxRequests requests are used.
 37 | // It returns ErrNoResults if no results
 38 | // are found.
 39 | func (g *GithubClient) SearchFile(s FileSearch) (contents []byte, err error) {
 40 | 	return g.firstFileSearch(&s, "/")
 41 | }
 42 | 
 43 | func (g *GithubClient) firstFileSearch(s *FileSearch, dir string) (match []byte, err error) {
 44 | 	if s.MaxRequests == 0 {
 45 | 		return nil, ErrMaxRequests
 46 | 	}
 47 | 
 48 | 	u := url.URL{
 49 | 		Scheme: "https",
 50 | 		Host:   "api.github.com",
 51 | 		Path:   path.Join("/repos", s.Repository, "/contents", dir),
 52 | 	}
 53 | 
 54 | 	body, _, err := g.request(u.String())
 55 | 	if err != nil {
 56 | 		return nil, err
 57 | 	}
 58 | 
 59 | 	s.MaxRequests--
 60 | 
 61 | 	var result []entity
 62 | 	if err := json.Unmarshal(body, &result); err != nil {
 63 | 		return nil, err
 64 | 	}
 65 | 
 66 | 	for _, ent := range result {
 67 | 		if ent.Match(s) {
 68 | 			return g.readFile(s.Repository, ent.Path)
 69 | 		}
 70 | 	}
 71 | 
 72 | 	sourceDirectoryHeuristic(result, s.Repository)
 73 | 
 74 | 	for _, ent := range result {
 75 | 		if ent.Dir() {
 76 | 			match, err = g.firstFileSearch(s, ent.Path)
 77 | 			if match != nil || (err != nil && err != ErrNoResults) {
 78 | 				return
 79 | 			}
 80 | 		}
 81 | 	}
 82 | 
 83 | 	return nil, ErrNoResults
 84 | }
 85 | 
 86 | func (g *GithubClient) readFile(repo, filePath string) ([]byte, error) {
 87 | 	u := url.URL{
 88 | 		Scheme: "https",
 89 | 		Host:   "api.github.com",
 90 | 		Path:   path.Join("/repos", repo, "/contents", filePath),
 91 | 	}
 92 | 	body, _, err := g.request(u.String())
 93 | 	if err != nil {
 94 | 		return nil, err
 95 | 	}
 96 | 
 97 | 	var result struct {
 98 | 		Content  string `json:"content"`
 99 | 		Encoding string `json:"encoding"`
100 | 	}
101 | 	if err := json.Unmarshal(body, &result); err != nil {
102 | 		return nil, err
103 | 	}
104 | 
105 | 	if result.Encoding == "base64" {
106 | 		return base64.StdEncoding.DecodeString(result.Content)
107 | 	} else {
108 | 		return nil, errors.New("unknown encoding: " + result.Encoding)
109 | 	}
110 | }
111 | 
112 | type entity struct {
113 | 	Name string `json:"name"`
114 | 	Path string `json:"path"`
115 | 	Size int    `json:"size"`
116 | 	Type string `json:"type"`
117 | }
118 | 
119 | func (e *entity) Dir() bool {
120 | 	return e.Type == "dir"
121 | }
122 | 
123 | func (e *entity) Match(s *FileSearch) bool {
124 | 	if e.Type != "file" {
125 | 		return false
126 | 	}
127 | 	if e.Size < s.MinFileSize || e.Size > s.MaxFileSize {
128 | 		return false
129 | 	}
130 | 	for _, ext := range s.Extensions {
131 | 		if strings.HasSuffix(e.Name, "."+ext) {
132 | 			return true
133 | 		}
134 | 	}
135 | 	return false
136 | }
137 | 
138 | // sourceDirectoryHeuristic sorts a list of
139 | // entities so that the first ones are more
140 | // likely to contain source code.
141 | func sourceDirectoryHeuristic(results []entity, repoName string) {
142 | 	sourceDirs := []string{"src", path.Base(repoName), "lib", "com", "org", "net", "css", "assets"}
143 | 	numFound := 0
144 | 	for _, sourceDir := range sourceDirs {
145 | 		for i, ent := range results[numFound:] {
146 | 			if ent.Dir() && ent.Name == sourceDir {
147 | 				results[numFound], results[numFound+i] = results[numFound+i], results[numFound]
148 | 				numFound++
149 | 				break
150 | 			}
151 | 		}
152 | 	}
153 | }
154 | 


--------------------------------------------------------------------------------
/cmd/fetchlang/github.go:
--------------------------------------------------------------------------------
  1 | package main
  2 | 
  3 | import (
  4 | 	"encoding/json"
  5 | 	"errors"
  6 | 	"fmt"
  7 | 	"io/ioutil"
  8 | 	"net/http"
  9 | 	"net/url"
 10 | 	"os"
 11 | 	"regexp"
 12 | 	"strings"
 13 | 
 14 | 	"github.com/howeyc/gopass"
 15 | )
 16 | 
 17 | // A GithubClient uses the Github API
 18 | // on behalf of a given user.
 19 | type GithubClient struct {
 20 | 	User string
 21 | 	Pass string
 22 | }
 23 | 
 24 | // PromptGithubClient prompts the user for their
 25 | // Github account details, then generates a
 26 | // *GithubClient based on these details.
 27 | func PromptGithubClient() (*GithubClient, error) {
 28 | 	fmt.Print("Username: ")
 29 | 	username := ""
 30 | 	for {
 31 | 		var ch [1]byte
 32 | 		_, err := os.Stdin.Read(ch[:])
 33 | 		if err != nil {
 34 | 			return nil, err
 35 | 		} else if ch[0] == '\n' {
 36 | 			break
 37 | 		} else if ch[0] == '\r' {
 38 | 			continue
 39 | 		}
 40 | 		username += string(ch[0])
 41 | 	}
 42 | 
 43 | 	fmt.Print("Password: ")
 44 | 	password, err := gopass.GetPasswd()
 45 | 	if err != nil {
 46 | 		return nil, err
 47 | 	}
 48 | 
 49 | 	return &GithubClient{
 50 | 		User: strings.TrimSpace(username),
 51 | 		Pass: string(password),
 52 | 	}, nil
 53 | }
 54 | 
 55 | // request accesses an API URL using the
 56 | // user's credentials.
 57 | //
 58 | // It returns an error if the request fails,
 59 | // or if Github's API returns an error.
 60 | //
 61 | // Some requests are naturally paginated, in
 62 | // which case the next return argument
 63 | // corresponds to the URL of the next page.
 64 | func (g *GithubClient) request(u string) (data []byte, next *url.URL, err error) {
 65 | 	req, err := http.NewRequest("GET", u, nil)
 66 | 	if err != nil {
 67 | 		return nil, nil, err
 68 | 	}
 69 | 	req.Header.Set("Accept", "application/vnd.github.preview.text-match+json")
 70 | 	req.SetBasicAuth(g.User, g.Pass)
 71 | 	res, err := http.DefaultClient.Do(req)
 72 | 	if res != nil {
 73 | 		defer res.Body.Close()
 74 | 	}
 75 | 	if err != nil {
 76 | 		return nil, nil, err
 77 | 	}
 78 | 	body, err := ioutil.ReadAll(res.Body)
 79 | 	if err != nil {
 80 | 		return nil, nil, err
 81 | 	}
 82 | 
 83 | 	var raw map[string]interface{}
 84 | 	if err := json.Unmarshal(body, &raw); err == nil {
 85 | 		if message, ok := raw["message"]; ok {
 86 | 			if s, ok := message.(string); ok {
 87 | 				return nil, nil, errors.New(s)
 88 | 			}
 89 | 		}
 90 | 	}
 91 | 
 92 | 	nextPattern := regexp.MustCompile("\\<(.*?)\\>; rel=\"next\"")
 93 | 	match := nextPattern.FindStringSubmatch(res.Header.Get("Link"))
 94 | 	if match != nil {
 95 | 		u, _ := url.Parse(match[1])
 96 | 		return body, u, nil
 97 | 	}
 98 | 
 99 | 	return body, nil, nil
100 | }
101 | 


--------------------------------------------------------------------------------
/cmd/fetchlang/main.go:
--------------------------------------------------------------------------------
  1 | package main
  2 | 
  3 | import (
  4 | 	"fmt"
  5 | 	"io/ioutil"
  6 | 	"os"
  7 | 	"path/filepath"
  8 | 	"strconv"
  9 | )
 10 | 
 11 | const MinFileSize = 100
 12 | const MaxFileSize = 500000
 13 | const MaxRequestsPerRepo = 20
 14 | 
 15 | type Language struct {
 16 | 	Name       string
 17 | 	Extensions []string
 18 | }
 19 | 
 20 | var Languages = []Language{
 21 | 	{"ActionScript", []string{"as"}},
 22 | 	{"C", []string{"c"}},
 23 | 	{"C#", []string{"cs"}},
 24 | 	{"C++", []string{"cpp", "c++", "C", "cc"}},
 25 | 	{"Clojure", []string{"clj"}},
 26 | 	{"CoffeeScript", []string{"coffee"}},
 27 | 	{"CSS", []string{"css"}},
 28 | 	{"Go", []string{"go"}},
 29 | 	{"Haskell", []string{"hs"}},
 30 | 	{"HTML", []string{"html", "htm"}},
 31 | 	{"Java", []string{"java"}},
 32 | 	{"JavaScript", []string{"js"}},
 33 | 	{"Lua", []string{"lua"}},
 34 | 	{"Matlab", []string{"m"}},
 35 | 	{"Objective-C", []string{"m"}},
 36 | 	{"Perl", []string{"pl"}},
 37 | 	{"PHP", []string{"php"}},
 38 | 	{"Python", []string{"py"}},
 39 | 	{"R", []string{"r", "R"}},
 40 | 	{"Ruby", []string{"rb"}},
 41 | 	{"Scala", []string{"scala"}},
 42 | 	{"Shell", []string{"sh", "bash"}},
 43 | 	{"Swift", []string{"swift"}},
 44 | 	{"TeX", []string{"tex", "TeX"}},
 45 | 	{"VimL", []string{"vim"}},
 46 | }
 47 | 
 48 | func main() {
 49 | 	if len(os.Args) != 3 {
 50 | 		fmt.Fprintln(os.Stderr, "Usage: fetchlang <sample-dir> <sample-count>")
 51 | 		os.Exit(1)
 52 | 	}
 53 | 	sampleDir := os.Args[1]
 54 | 	sampleCount, err := strconv.Atoi(os.Args[2])
 55 | 	if err != nil {
 56 | 		fmt.Fprintln(os.Stderr, "Invalid sample count:", os.Args[2])
 57 | 		os.Exit(1)
 58 | 	}
 59 | 
 60 | 	client, err := PromptGithubClient()
 61 | 	if err != nil {
 62 | 		fmt.Fprintln(os.Stderr, "Failed to read credentials:", err)
 63 | 		os.Exit(1)
 64 | 	}
 65 | 
 66 | 	for _, lang := range Languages {
 67 | 		fmt.Println("Fetching samples for", lang.Name, "...")
 68 | 		err := fetchLanguage(client, sampleDir, sampleCount, lang)
 69 | 		if err != nil {
 70 | 			fmt.Println("Error:", err)
 71 | 		}
 72 | 	}
 73 | }
 74 | 
 75 | func fetchLanguage(github *GithubClient, sampleDir string, count int, lang Language) error {
 76 | 	if err := os.Mkdir(filepath.Join(sampleDir, lang.Name), 0755); err != nil {
 77 | 		return err
 78 | 	}
 79 | 
 80 | 	doneChan := make(chan struct{}, 1)
 81 | 	repoChan, errChan := github.Search(lang.Name, doneChan)
 82 | 
 83 | 	defer func() {
 84 | 		close(doneChan)
 85 | 	}()
 86 | 
 87 | 	resCount := 0
 88 | 	for repo := range repoChan {
 89 | 		file, err := github.SearchFile(FileSearch{
 90 | 			Repository:  repo,
 91 | 			MinFileSize: MinFileSize,
 92 | 			MaxFileSize: MaxFileSize,
 93 | 			Extensions:  lang.Extensions,
 94 | 			MaxRequests: MaxRequestsPerRepo,
 95 | 		})
 96 | 		if err == ErrNoResults {
 97 | 			fmt.Println("No results for:", repo)
 98 | 			continue
 99 | 		} else if err == ErrMaxRequests {
100 | 			fmt.Println("Max requests exceeded:", repo)
101 | 			continue
102 | 		} else if err != nil {
103 | 			return err
104 | 		} else if file == nil {
105 | 			continue
106 | 		}
107 | 
108 | 		fileName := fmt.Sprintf("%d.%s", resCount, lang.Extensions[0])
109 | 		targetFile := filepath.Join(sampleDir, lang.Name, fileName)
110 | 		if err := ioutil.WriteFile(targetFile, file, 0755); err != nil {
111 | 			return err
112 | 		}
113 | 
114 | 		resCount++
115 | 		if resCount == count {
116 | 			break
117 | 		}
118 | 	}
119 | 
120 | 	select {
121 | 	case err := <-errChan:
122 | 		return err
123 | 	default:
124 | 		return nil
125 | 	}
126 | }
127 | 


--------------------------------------------------------------------------------
/cmd/fetchlang/repo_search.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"encoding/json"
 5 | 	"net/url"
 6 | )
 7 | 
 8 | // Search asynchronously lists repositories which
 9 | // are written in the given programming language.
10 | // Repository names are of the form "user/repo".
11 | //
12 | // The caller should close the done argument when
13 | // they do not need any more results.
14 | // When no results are left, or when done is closed,
15 | // or on error, both returned channels are closed.
16 | //
17 | // If search results cannot be obtained, an error
18 | // is sent on the error channel.
19 | func (g *GithubClient) Search(lang string, done <-chan struct{}) (<-chan string, <-chan error) {
20 | 	nameChan := make(chan string, 0)
21 | 	errChan := make(chan error, 1)
22 | 	go func() {
23 | 		defer func() {
24 | 			close(nameChan)
25 | 			close(errChan)
26 | 		}()
27 | 		u := repositorySearchURL(lang)
28 | 		for u != nil {
29 | 			select {
30 | 			case <-done:
31 | 				return
32 | 			default:
33 | 			}
34 | 
35 | 			body, next, err := g.request(u.String())
36 | 			if err != nil {
37 | 				errChan <- err
38 | 				return
39 | 			}
40 | 			u = next
41 | 
42 | 			var obj struct {
43 | 				Items []struct {
44 | 					FullName string `json:"full_name"`
45 | 				} `json:"items"`
46 | 			}
47 | 			if err := json.Unmarshal(body, &obj); err != nil {
48 | 				errChan <- err
49 | 				return
50 | 			}
51 | 
52 | 			for _, x := range obj.Items {
53 | 				select {
54 | 				case nameChan <- x.FullName:
55 | 				case <-done:
56 | 					return
57 | 				}
58 | 			}
59 | 		}
60 | 	}()
61 | 	return nameChan, errChan
62 | }
63 | 
64 | func repositorySearchURL(lang string) *url.URL {
65 | 	return &url.URL{
66 | 		Scheme: "https",
67 | 		Host:   "api.github.com",
68 | 		Path:   "search/repositories",
69 | 		RawQuery: url.Values{
70 | 			"order": []string{"desc"},
71 | 			"q":     []string{"language:" + lang},
72 | 		}.Encode(),
73 | 	}
74 | }
75 | 


--------------------------------------------------------------------------------
/cmd/rater/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"io/ioutil"
 6 | 	"os"
 7 | 
 8 | 	"github.com/unixpickle/whichlang"
 9 | 	"github.com/unixpickle/whichlang/tokens"
10 | )
11 | 
12 | func main() {
13 | 	if len(os.Args) != 4 {
14 | 		fmt.Fprintln(os.Stderr, "Usage: rater <algorithm> <classifier-file> <sample-dir>")
15 | 		os.Exit(1)
16 | 	}
17 | 
18 | 	decoder := whichlang.Decoders[os.Args[1]]
19 | 	if decoder == nil {
20 | 		fmt.Fprintln(os.Stderr, "Unknown algorithm:", os.Args[1])
21 | 		os.Exit(1)
22 | 	}
23 | 
24 | 	classifierData, err := ioutil.ReadFile(os.Args[2])
25 | 	if err != nil {
26 | 		fmt.Fprintln(os.Stderr, err)
27 | 		os.Exit(1)
28 | 	}
29 | 
30 | 	classifier, err := decoder(classifierData)
31 | 	if err != nil {
32 | 		fmt.Fprintln(os.Stderr, "Failed to decode classifier:", err)
33 | 		os.Exit(1)
34 | 	}
35 | 
36 | 	samples, err := tokens.ReadSampleCounts(os.Args[3])
37 | 	if err != nil {
38 | 		fmt.Fprintln(os.Stderr, "Failed to read samples:", err)
39 | 		os.Exit(1)
40 | 	}
41 | 
42 | 	rating := Rate(classifier, samples)
43 | 
44 | 	fmt.Printf("Success rate: %d/%d or %0.2f%%\n", rating.Correct, rating.Total,
45 | 		100*rating.Frac())
46 | 
47 | 	nameLength := len(rating.LongestLangName())
48 | 	for _, rating := range rating.LangRatings {
49 | 		paddedName := rating.Language
50 | 		for len(paddedName) < nameLength {
51 | 			paddedName = " " + paddedName
52 | 		}
53 | 		fmt.Printf("%s  success rate %d/%d or %0.2f%%\n", paddedName,
54 | 			rating.Correct, rating.Total, 100*rating.Frac())
55 | 	}
56 | }
57 | 


--------------------------------------------------------------------------------
/cmd/rater/rate.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"runtime"
 5 | 	"sync"
 6 | 
 7 | 	"github.com/unixpickle/whichlang"
 8 | 	"github.com/unixpickle/whichlang/tokens"
 9 | )
10 | 
11 | type Challenge struct {
12 | 	Language string
13 | 	Sample   tokens.Freqs
14 | }
15 | 
16 | type Result struct {
17 | 	Language string
18 | 	Correct  bool
19 | }
20 | 
21 | func Rate(c whichlang.Classifier, s map[string][]tokens.Counts) *OverallRating {
22 | 	var wg sync.WaitGroup
23 | 	challengeChan := make(chan Challenge, 0)
24 | 	resultChan := make(chan Result, 0)
25 | 
26 | 	for i := 0; i < runtime.GOMAXPROCS(0); i++ {
27 | 		wg.Add(1)
28 | 		go func() {
29 | 			defer wg.Done()
30 | 			for challenge := range challengeChan {
31 | 				correct := (c.Classify(challenge.Sample) == challenge.Language)
32 | 				resultChan <- Result{
33 | 					Language: challenge.Language,
34 | 					Correct:  correct,
35 | 				}
36 | 			}
37 | 		}()
38 | 	}
39 | 
40 | 	go func() {
41 | 		for lang, langSamples := range s {
42 | 			for _, sample := range langSamples {
43 | 				challengeChan <- Challenge{lang, sample.Freqs()}
44 | 			}
45 | 		}
46 | 		close(challengeChan)
47 | 	}()
48 | 
49 | 	go func() {
50 | 		wg.Wait()
51 | 		close(resultChan)
52 | 	}()
53 | 
54 | 	var total, successes int
55 | 	langSuccesses := map[string]int{}
56 | 	langTotals := map[string]int{}
57 | 
58 | 	for result := range resultChan {
59 | 		total++
60 | 		langTotals[result.Language]++
61 | 		if result.Correct {
62 | 			successes++
63 | 			langSuccesses[result.Language]++
64 | 		}
65 | 	}
66 | 
67 | 	langRatings := makeLangRatings(langSuccesses, langTotals)
68 | 	return NewOverallRating(successes, total, langRatings)
69 | }
70 | 
71 | func makeLangRatings(succ, total map[string]int) []*LangRating {
72 | 	res := make([]*LangRating, 0, len(total))
73 | 	for lang, totalCount := range total {
74 | 		lr := &LangRating{
75 | 			Rating: Rating{
76 | 				Correct: succ[lang],
77 | 				Total:   totalCount,
78 | 			},
79 | 			Language: lang,
80 | 		}
81 | 		res = append(res, lr)
82 | 	}
83 | 	return res
84 | }
85 | 


--------------------------------------------------------------------------------
/cmd/rater/types.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import "sort"
 4 | 
 5 | type Rating struct {
 6 | 	Correct int
 7 | 	Total   int
 8 | }
 9 | 
10 | func (r *Rating) Frac() float64 {
11 | 	return float64(r.Correct) / float64(r.Total)
12 | }
13 | 
14 | type OverallRating struct {
15 | 	Rating
16 | 	LangRatings []*LangRating
17 | }
18 | 
19 | func NewOverallRating(correct, total int, l []*LangRating) *OverallRating {
20 | 	sorter := ratingSorter(make([]*LangRating, len(l)))
21 | 	copy(sorter, l)
22 | 	sort.Sort(sorter)
23 | 	return &OverallRating{
24 | 		Rating: Rating{
25 | 			Correct: correct,
26 | 			Total:   total,
27 | 		},
28 | 		LangRatings: []*LangRating(sorter),
29 | 	}
30 | }
31 | 
32 | func (o *OverallRating) LongestLangName() string {
33 | 	var longest string
34 | 	for _, lang := range o.LangRatings {
35 | 		if len(lang.Language) > len(longest) {
36 | 			longest = lang.Language
37 | 		}
38 | 	}
39 | 	return longest
40 | }
41 | 
42 | type LangRating struct {
43 | 	Rating
44 | 	Language string
45 | }
46 | 
47 | type ratingSorter []*LangRating
48 | 
49 | func (r ratingSorter) Len() int {
50 | 	return len(r)
51 | }
52 | 
53 | func (r ratingSorter) Less(i, j int) bool {
54 | 	return r[i].Frac() > r[j].Frac()
55 | }
56 | 
57 | func (r ratingSorter) Swap(i, j int) {
58 | 	r[i], r[j] = r[j], r[i]
59 | }
60 | 


--------------------------------------------------------------------------------
/cmd/server/assets/index.html:
--------------------------------------------------------------------------------
 1 | <!doctype html>
 2 | <html>
 3 |   <head>
 4 |     <meta charset="utf-8">
 5 |     <meta name="viewport"
 6 |           content="width=device-width, initial-scale=1">
 7 |     <title>whichlang</title>
 8 |     <link rel="stylesheet" type="text/css" href="style.css">
 9 |     <script src="main.js"></script>
10 |   </head>
11 |   <body>
12 |     <textarea id="text-input" placeholder="Paste some code..."></textarea>
13 |     <br>
14 |     <button id="classify-button">Classify</button>
15 |     <div id="classification" style="display: none"></div>
16 |   </body>
17 | </html>
18 | 
19 | 


--------------------------------------------------------------------------------
/cmd/server/assets/main.js:
--------------------------------------------------------------------------------
 1 | (function() {
 2 | 
 3 |   var textInput;
 4 |   var request = null;
 5 |   var classificationLabel;
 6 | 
 7 |   function classify() {
 8 |     classificationLabel.style.display = 'block';
 9 |     classificationLabel.innerText = 'Loading...';
10 |     if (request !== null) {
11 |       request.abort();
12 |     }
13 |     request = new XMLHttpRequest();
14 |     request.onreadystatechange = function() {
15 |       if (request.readyState === 4) {
16 |         handleClassification(request.responseText);
17 |         request = null;
18 |       }
19 |     };
20 |     var time = new Date().getTime();
21 |     request.open('POST', '/classify?time=' + time, true);
22 |     request.send(textInput.value);
23 |   }
24 | 
25 |   function handleClassification(classification) {
26 |     var obj = JSON.parse(classification);
27 | 
28 |     classificationLabel.innerText = 'Classification: ' + obj.lang;
29 |     classificationLabel.style.display = 'block';
30 |   }
31 | 
32 |   window.addEventListener('load', function() {
33 |     textInput = document.getElementById('text-input');
34 |     classificationLabel = document.getElementById('classification');
35 |     var classifyButton = document.getElementById('classify-button');
36 |     classifyButton.addEventListener('click', classify);
37 |   });
38 | 
39 | })();
40 | 


--------------------------------------------------------------------------------
/cmd/server/assets/style.css:
--------------------------------------------------------------------------------
 1 | body {
 2 |   text-align: center;
 3 |   font-family: sans-serif;
 4 | }
 5 | 
 6 | #text-input {
 7 |   resize: none;
 8 |   margin-top: 10px;
 9 |   height: 222px;
10 |   box-sizing: border-box;
11 |   padding: 5px 5px;
12 |   border: 1px solid #375eaa;
13 |   border-radius: 5px;
14 |   font-size: 14px;
15 |   background-color: #ffffd8;
16 |   font-family: monospace;
17 | }
18 | 
19 | #text-input:focus {
20 |   outline: 0;
21 | }
22 | 
23 | @media (max-width: 400px) {
24 |   #text-input {
25 |     width: calc(100% - 20px);
26 |   }
27 | }
28 | 
29 | @media (min-width: 400px) {
30 |   #text-input {
31 |     width: 360px;
32 |   }
33 | }
34 | 
35 | #classify-button {
36 |   margin: 5px 0 0 0;
37 |   width: 110px;
38 |   height: 35px;
39 |   border: 1px solid #375eaa;
40 |   border-radius: 5px;
41 |   background-color: #e0ebf5;
42 |   color: black;
43 |   font-size: 16px;
44 |   cursor: pointer;
45 | }
46 | 
47 | #classification {
48 |   margin-top: 10px;
49 | }
50 | 


--------------------------------------------------------------------------------
/cmd/server/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"encoding/json"
 5 | 	"fmt"
 6 | 	"io/ioutil"
 7 | 	"net/http"
 8 | 	"os"
 9 | 
10 | 	"github.com/unixpickle/whichlang"
11 | 	"github.com/unixpickle/whichlang/tokens"
12 | )
13 | 
14 | func main() {
15 | 	if len(os.Args) != 5 {
16 | 		fmt.Fprintln(os.Stderr, "Usage: server <algorithm> <classifier-file> <asset-dir> <port>")
17 | 		os.Exit(1)
18 | 	}
19 | 
20 | 	classifier := readClassifier()
21 | 	assets := os.Args[3]
22 | 
23 | 	http.HandleFunc("/classify", func(w http.ResponseWriter, r *http.Request) {
24 | 		contents, err := ioutil.ReadAll(r.Body)
25 | 		if err != nil {
26 | 			http.Error(w, err.Error(), http.StatusInternalServerError)
27 | 			return
28 | 		}
29 | 		counts := tokens.CountTokens(string(contents))
30 | 		freqs := counts.Freqs()
31 | 		lang := classifier.Classify(freqs)
32 | 		jsonObj := map[string]interface{}{"lang": lang}
33 | 		jsonData, _ := json.Marshal(jsonObj)
34 | 		w.Header().Set("Content-Type", "application/json")
35 | 		w.Write(jsonData)
36 | 	})
37 | 	http.Handle("/", http.FileServer(http.Dir(assets)))
38 | 
39 | 	if err := http.ListenAndServe(":"+os.Args[4], nil); err != nil {
40 | 		fmt.Fprintln(os.Stderr, err)
41 | 		os.Exit(1)
42 | 	}
43 | }
44 | 
45 | func readClassifier() whichlang.Classifier {
46 | 	decoder := whichlang.Decoders[os.Args[1]]
47 | 	if decoder == nil {
48 | 		fmt.Fprintln(os.Stderr, "Unknown algorithm:", os.Args[1])
49 | 		os.Exit(1)
50 | 	}
51 | 
52 | 	data, err := ioutil.ReadFile(os.Args[2])
53 | 	if err != nil {
54 | 		fmt.Fprintln(os.Stderr, err)
55 | 		os.Exit(1)
56 | 	}
57 | 
58 | 	c, err := decoder(data)
59 | 	if err != nil {
60 | 		fmt.Fprintln(os.Stderr, err)
61 | 		os.Exit(1)
62 | 	}
63 | 
64 | 	return c
65 | }
66 | 


--------------------------------------------------------------------------------
/cmd/subsamples/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"io/ioutil"
 6 | 	"os"
 7 | 	"path/filepath"
 8 | 	"strconv"
 9 | 	"strings"
10 | )
11 | 
12 | func main() {
13 | 	if len(os.Args) != 3 {
14 | 		fmt.Fprintf(os.Stderr, "Usage: %s <num-lines> <sample-dir>\n", os.Args[0])
15 | 		os.Exit(1)
16 | 	}
17 | 	numLines, err := strconv.Atoi(os.Args[1])
18 | 	if err != nil {
19 | 		fmt.Fprintln(os.Stderr, err)
20 | 		os.Exit(1)
21 | 	}
22 | 	sampleDir := os.Args[2]
23 | 	if err := subsample(numLines, sampleDir); err != nil {
24 | 		fmt.Fprintln(os.Stderr, err)
25 | 		os.Exit(1)
26 | 	}
27 | }
28 | 
29 | func subsample(numLines int, dirPath string) error {
30 | 	dir, err := os.Open(dirPath)
31 | 	if err != nil {
32 | 		return err
33 | 	}
34 | 	listing, err := dir.Readdir(-1)
35 | 	dir.Close()
36 | 	if err != nil {
37 | 		return err
38 | 	}
39 | 	for _, langDir := range listing {
40 | 		if !langDir.IsDir() {
41 | 			continue
42 | 		}
43 | 		p := filepath.Join(dirPath, langDir.Name())
44 | 		if err := subsampleLang(numLines, p); err != nil {
45 | 			return err
46 | 		}
47 | 	}
48 | 	return nil
49 | }
50 | 
51 | func subsampleLang(numLines int, langDirPath string) error {
52 | 	dir, err := os.Open(langDirPath)
53 | 	if err != nil {
54 | 		return err
55 | 	}
56 | 	listing, err := dir.Readdir(-1)
57 | 	dir.Close()
58 | 	if err != nil {
59 | 		return err
60 | 	}
61 | 	for _, fileInfo := range listing {
62 | 		if strings.HasPrefix(fileInfo.Name(), ".") {
63 | 			continue
64 | 		}
65 | 		filePath := filepath.Join(langDirPath, fileInfo.Name())
66 | 		if err := subsampleFile(numLines, filePath); err != nil {
67 | 			return err
68 | 		}
69 | 	}
70 | 	return nil
71 | }
72 | 
73 | func subsampleFile(numLines int, filePath string) error {
74 | 	contents, err := ioutil.ReadFile(filePath)
75 | 	if err != nil {
76 | 		return err
77 | 	}
78 | 	lines := strings.Split(string(contents), "\n")
79 | 	if len(lines) < numLines {
80 | 		fmt.Println("Skipping file", filePath)
81 | 		return nil
82 | 	}
83 | 	startIndex := (len(lines) - numLines) / 2
84 | 	splitLines := lines[startIndex : startIndex+numLines]
85 | 	newPath := newPath(numLines, filePath)
86 | 	newData := strings.Join(splitLines, "\n")
87 | 	return ioutil.WriteFile(newPath, []byte(newData), 0755)
88 | }
89 | 
90 | func newPath(numLines int, filePath string) string {
91 | 	sampleTag := "_subsample_" + strconv.Itoa(numLines)
92 | 	ext := filepath.Ext(filePath)
93 | 	if ext == "" {
94 | 		return filePath + sampleTag
95 | 	} else {
96 | 		return filePath[:len(filePath)-len(ext)] + sampleTag + ext
97 | 	}
98 | }
99 | 


--------------------------------------------------------------------------------
/cmd/svm-shrink/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"errors"
 5 | 	"fmt"
 6 | 	"io/ioutil"
 7 | 	"os"
 8 | 
 9 | 	"github.com/unixpickle/num-analysis/linalg"
10 | 	"github.com/unixpickle/whichlang/svm"
11 | )
12 | 
13 | func main() {
14 | 	if len(os.Args) != 3 {
15 | 		fmt.Println("Usage: svm-shrink <classifier-file> <output-file>")
16 | 		os.Exit(1)
17 | 	}
18 | 
19 | 	data, err := ioutil.ReadFile(os.Args[1])
20 | 	if err != nil {
21 | 		die(err)
22 | 	}
23 | 
24 | 	classifier, err := svm.DecodeClassifier(data)
25 | 	if err != nil {
26 | 		die(err)
27 | 	} else if classifier.Kernel.Type != svm.LinearKernel {
28 | 		die(errors.New("can only shrink linear classifiers"))
29 | 	}
30 | 
31 | 	langs := classifier.Languages()
32 | 
33 | 	newClassifier := &svm.Classifier{
34 | 		Keywords:      classifier.Keywords,
35 | 		Kernel:        classifier.Kernel,
36 | 		SampleVectors: make([]linalg.Vector, len(langs)),
37 | 		Classifiers:   map[string]svm.BinaryClassifier{},
38 | 	}
39 | 
40 | 	for i, lang := range langs {
41 | 		newClassifier.SampleVectors[i] = combineLanguageVecs(classifier, lang)
42 | 		bc := svm.BinaryClassifier{
43 | 			SupportVectors: []int{i},
44 | 			Weights:        []float64{1},
45 | 			Threshold:      classifier.Classifiers[lang].Threshold,
46 | 		}
47 | 		newClassifier.Classifiers[lang] = bc
48 | 	}
49 | 
50 | 	encoded := newClassifier.Encode()
51 | 	if err := ioutil.WriteFile(os.Args[2], encoded, 0755); err != nil {
52 | 		die(err)
53 | 	}
54 | }
55 | 
56 | func combineLanguageVecs(c *svm.Classifier, lang string) linalg.Vector {
57 | 	sum := make(linalg.Vector, len(c.Keywords))
58 | 	bc := c.Classifiers[lang]
59 | 	for i, idx := range bc.SupportVectors {
60 | 		sum.Add(c.SampleVectors[idx].Copy().Scale(bc.Weights[i]))
61 | 	}
62 | 	return sum
63 | }
64 | 
65 | func die(e error) {
66 | 	fmt.Fprintln(os.Stderr, e)
67 | 	os.Exit(1)
68 | }
69 | 


--------------------------------------------------------------------------------
/cmd/trainer/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"io/ioutil"
 6 | 	"math/rand"
 7 | 	"os"
 8 | 	"strconv"
 9 | 	"time"
10 | 
11 | 	"github.com/unixpickle/whichlang"
12 | 	"github.com/unixpickle/whichlang/tokens"
13 | )
14 | 
15 | const HelpColumnSize = 10
16 | 
17 | func main() {
18 | 	// Several machine learning algorithms depend on
19 | 	// random starting positions.
20 | 	rand.Seed(time.Now().UnixNano())
21 | 
22 | 	if len(os.Args) != 5 {
23 | 		dieUsage()
24 | 	}
25 | 
26 | 	algorithm := os.Args[1]
27 | 
28 | 	trainer := whichlang.Trainers[algorithm]
29 | 	if trainer == nil {
30 | 		fmt.Fprintln(os.Stderr, "Unknown algorithm:", algorithm)
31 | 		dieUsage()
32 | 	}
33 | 
34 | 	ubiquity, err := strconv.Atoi(os.Args[2])
35 | 	if err != nil {
36 | 		fmt.Fprintln(os.Stderr, "Invalid ubiquity:", os.Args[2], "(expected integer)")
37 | 		os.Exit(1)
38 | 	}
39 | 
40 | 	sampleDir := os.Args[3]
41 | 	outputFile := os.Args[4]
42 | 
43 | 	counts, err := tokens.ReadSampleCounts(sampleDir)
44 | 	if err != nil {
45 | 		fmt.Fprintln(os.Stderr, err)
46 | 		os.Exit(1)
47 | 	}
48 | 
49 | 	oldCount := counts.NumTokens()
50 | 	fmt.Println("Pruning tokens...")
51 | 	counts.Prune(ubiquity)
52 | 	newCount := counts.NumTokens()
53 | 	fmt.Printf("Pruned %d/%d tokens (%d left).\n", (oldCount - newCount),
54 | 		oldCount, newCount)
55 | 
56 | 	freqs := counts.SampleFreqs()
57 | 
58 | 	fmt.Println("Training...")
59 | 	classifier := trainer(freqs)
60 | 
61 | 	fmt.Println("Saving...")
62 | 	data := classifier.Encode()
63 | 
64 | 	if err := ioutil.WriteFile(outputFile, data, 0755); err != nil {
65 | 		fmt.Fprintln(os.Stderr, "Error writing file:", err)
66 | 		os.Exit(1)
67 | 	}
68 | }
69 | 
70 | func dieUsage() {
71 | 	fmt.Fprintln(os.Stderr, "Usage: trainer <algorithm> <ubiquity> <sample-dir> <output-file>\n\n"+
72 | 		" (ubiquity specifies the number of files in which a\n keyword should appear.)\n\n"+
73 | 		"Available algorithms:")
74 | 	for _, name := range whichlang.ClassifierNames {
75 | 		spaces := ""
76 | 		for i := len(name); i < HelpColumnSize; i++ {
77 | 			spaces += " "
78 | 		}
79 | 		fmt.Fprintln(os.Stderr, " "+name+spaces, whichlang.Descriptions[name])
80 | 	}
81 | 	fmt.Fprintln(os.Stderr, "")
82 | 	os.Exit(1)
83 | }
84 | 


--------------------------------------------------------------------------------
/gaussbayes/classifier.go:
--------------------------------------------------------------------------------
 1 | // Package gaussbayes implements naive Bayesian
 2 | // classification using the assumption that token
 3 | // frequencies follow Gaussian distributions.
 4 | package gaussbayes
 5 | 
 6 | import (
 7 | 	"encoding/json"
 8 | 	"math"
 9 | 
10 | 	"github.com/unixpickle/whichlang/tokens"
11 | )
12 | 
13 | // Gaussian is a Gaussian probability distribution.
14 | type Gaussian struct {
15 | 	Mean     float64
16 | 	Variance float64
17 | }
18 | 
19 | // EvalLog evaluates the natural logarithm of the
20 | // density function at a given x value.
21 | func (g Gaussian) EvalLog(x float64) float64 {
22 | 	coeff := 1 / math.Sqrt(2*g.Variance*math.Pi)
23 | 	exp := -math.Pow(x-g.Mean, 2) / (2 * g.Variance)
24 | 	return math.Log(coeff) + exp
25 | }
26 | 
27 | type Classifier struct {
28 | 	LangGaussians map[string]map[string]Gaussian
29 | }
30 | 
31 | func DecodeClassifier(d []byte) (*Classifier, error) {
32 | 	var c Classifier
33 | 	if err := json.Unmarshal(d, &c); err != nil {
34 | 		return nil, err
35 | 	}
36 | 	return &c, nil
37 | }
38 | 
39 | func (c *Classifier) Classify(f tokens.Freqs) string {
40 | 	var bestLanguage string
41 | 	var bestLogProbability float64
42 | 	for lang, dists := range c.LangGaussians {
43 | 		var probLog float64
44 | 		for token, gaussian := range dists {
45 | 			probLog += gaussian.EvalLog(f[token])
46 | 		}
47 | 		if bestLanguage == "" || probLog > bestLogProbability {
48 | 			bestLanguage = lang
49 | 			bestLogProbability = probLog
50 | 		}
51 | 	}
52 | 	return bestLanguage
53 | }
54 | 
55 | func (c *Classifier) Encode() []byte {
56 | 	res, _ := json.Marshal(c)
57 | 	return res
58 | }
59 | 
60 | func (c *Classifier) Languages() []string {
61 | 	var languages []string
62 | 	for lang := range c.LangGaussians {
63 | 		languages = append(languages, lang)
64 | 	}
65 | 	return languages
66 | }
67 | 


--------------------------------------------------------------------------------
/gaussbayes/train.go:
--------------------------------------------------------------------------------
 1 | package gaussbayes
 2 | 
 3 | import (
 4 | 	"math"
 5 | 
 6 | 	"github.com/unixpickle/whichlang/tokens"
 7 | )
 8 | 
 9 | // Train returns a *Classifier by computing statistical
10 | // properties of the sample data.
11 | func Train(freqs map[string][]tokens.Freqs) *Classifier {
12 | 	res := &Classifier{LangGaussians: map[string]map[string]Gaussian{}}
13 | 	toks := allTokens(freqs)
14 | 	for lang, samples := range freqs {
15 | 		gaussians := computeGaussians(samples)
16 | 		addMissing(gaussians, toks)
17 | 		res.LangGaussians[lang] = gaussians
18 | 	}
19 | 	regularizeVariances(res)
20 | 	return res
21 | }
22 | 
23 | func computeGaussians(samples []tokens.Freqs) map[string]Gaussian {
24 | 	res := map[string]Gaussian{}
25 | 
26 | 	computeMeans(samples, res)
27 | 	computeVariances(samples, res)
28 | 
29 | 	return res
30 | }
31 | 
32 | func computeMeans(samples []tokens.Freqs, out map[string]Gaussian) {
33 | 	for _, sample := range samples {
34 | 		for keyword, freq := range sample {
35 | 			outGaussian := out[keyword]
36 | 			outGaussian.Mean += freq
37 | 			out[keyword] = outGaussian
38 | 		}
39 | 	}
40 | 	meanScaler := 1 / float64(len(samples))
41 | 	for key, g := range out {
42 | 		g.Mean *= meanScaler
43 | 		out[key] = g
44 | 	}
45 | }
46 | 
47 | func computeVariances(samples []tokens.Freqs, out map[string]Gaussian) {
48 | 	for _, sample := range samples {
49 | 		for keyword, freq := range sample {
50 | 			outGaussian := out[keyword]
51 | 			outGaussian.Variance += math.Pow(freq-outGaussian.Mean, 2)
52 | 			out[keyword] = outGaussian
53 | 		}
54 | 	}
55 | 	varianceScaler := 1 / float64(len(samples))
56 | 	for key, g := range out {
57 | 		g.Variance *= varianceScaler
58 | 		out[key] = g
59 | 	}
60 | }
61 | 
62 | func addMissing(m map[string]Gaussian, toks []string) {
63 | 	for _, token := range toks {
64 | 		if _, ok := m[token]; !ok {
65 | 			m[token] = Gaussian{}
66 | 		}
67 | 	}
68 | }
69 | 
70 | func allTokens(m map[string][]tokens.Freqs) []string {
71 | 	res := map[string]bool{}
72 | 	for _, x := range m {
73 | 		for _, t := range x {
74 | 			for w := range t {
75 | 				res[w] = true
76 | 			}
77 | 		}
78 | 	}
79 | 	resSlice := make([]string, 0, len(res))
80 | 	for w := range res {
81 | 		resSlice = append(resSlice, w)
82 | 	}
83 | 	return resSlice
84 | }
85 | 
86 | // regularizeVariances ensures that no variances are zero.
87 | func regularizeVariances(c *Classifier) {
88 | 	var smallestVariance float64
89 | 	for _, m := range c.LangGaussians {
90 | 		for _, x := range m {
91 | 			if smallestVariance == 0 || (x.Variance < smallestVariance && x.Variance > 0) {
92 | 				smallestVariance = x.Variance
93 | 			}
94 | 		}
95 | 	}
96 | 	for _, m := range c.LangGaussians {
97 | 		for word, x := range m {
98 | 			if x.Variance == 0 {
99 | 				x.Variance = smallestVariance
100 | 				m[word] = x
101 | 			}
102 | 		}
103 | 	}
104 | }
105 | 


--------------------------------------------------------------------------------
/idtree/classifier.go:
--------------------------------------------------------------------------------
 1 | package idtree
 2 | 
 3 | import (
 4 | 	"encoding/json"
 5 | 
 6 | 	"github.com/unixpickle/whichlang/tokens"
 7 | )
 8 | 
 9 | type Classifier struct {
10 | 	LeafClassification *string
11 | 
12 | 	Keyword   string
13 | 	Threshold float64
14 | 
15 | 	FalseBranch *Classifier
16 | 	TrueBranch  *Classifier
17 | }
18 | 
19 | func DecodeClassifier(d []byte) (*Classifier, error) {
20 | 	var res Classifier
21 | 	if err := json.Unmarshal(d, &res); err != nil {
22 | 		return nil, err
23 | 	}
24 | 	return &res, nil
25 | }
26 | 
27 | func (c *Classifier) Classify(f tokens.Freqs) string {
28 | 	if c.LeafClassification == nil {
29 | 		if f[c.Keyword] > c.Threshold {
30 | 			return c.TrueBranch.Classify(f)
31 | 		} else {
32 | 			return c.FalseBranch.Classify(f)
33 | 		}
34 | 	} else {
35 | 		return *c.LeafClassification
36 | 	}
37 | }
38 | 
39 | func (c *Classifier) Encode() []byte {
40 | 	res, _ := json.Marshal(c)
41 | 	return res
42 | }
43 | 
44 | func (c *Classifier) Languages() []string {
45 | 	if c.LeafClassification != nil {
46 | 		return []string{*c.LeafClassification}
47 | 	}
48 | 
49 | 	seen := map[string]bool{}
50 | 	for _, lang := range c.FalseBranch.Languages() {
51 | 		seen[lang] = true
52 | 	}
53 | 	for _, lang := range c.TrueBranch.Languages() {
54 | 		seen[lang] = true
55 | 	}
56 | 
57 | 	res := make([]string, 0, len(seen))
58 | 	for lang := range seen {
59 | 		res = append(res, lang)
60 | 	}
61 | 	return res
62 | }
63 | 


--------------------------------------------------------------------------------
/idtree/samples.go:
--------------------------------------------------------------------------------
 1 | package idtree
 2 | 
 3 | import "github.com/unixpickle/whichlang/tokens"
 4 | 
 5 | type linearSample struct {
 6 | 	freqs []float64
 7 | 	lang  string
 8 | }
 9 | 
10 | func freqsToLinearSamples(toks []string, freqs map[string][]tokens.Freqs) []linearSample {
11 | 	var res []linearSample
12 | 	for lang, freqsList := range freqs {
13 | 		for _, freqs := range freqsList {
14 | 			s := linearSample{
15 | 				lang:  lang,
16 | 				freqs: make([]float64, len(toks)),
17 | 			}
18 | 			for i, tok := range toks {
19 | 				s.freqs[i] = freqs[tok]
20 | 			}
21 | 			res = append(res, s)
22 | 		}
23 | 	}
24 | 	return res
25 | }
26 | 
27 | func languageMajority(samples []linearSample) string {
28 | 	counts := map[string]int{}
29 | 	for _, sample := range samples {
30 | 		counts[sample.lang]++
31 | 	}
32 | 
33 | 	var maxCount int
34 | 	var maxLang string
35 | 	for lang, count := range counts {
36 | 		if count > maxCount {
37 | 			maxCount = count
38 | 			maxLang = lang
39 | 		}
40 | 	}
41 | 
42 | 	return maxLang
43 | }
44 | 
45 | // A sampleSorter implements sort.Interface
46 | // and facilitates sorting linear samples
47 | // by the frequency of a given token.
48 | type sampleSorter struct {
49 | 	samples  []linearSample
50 | 	tokenIdx int
51 | }
52 | 
53 | func (s *sampleSorter) Len() int {
54 | 	return len(s.samples)
55 | }
56 | 
57 | func (s *sampleSorter) Swap(i, j int) {
58 | 	s.samples[i], s.samples[j] = s.samples[j], s.samples[i]
59 | }
60 | 
61 | func (s *sampleSorter) Less(i, j int) bool {
62 | 	f1 := s.samples[i].freqs[s.tokenIdx]
63 | 	f2 := s.samples[j].freqs[s.tokenIdx]
64 | 	return f1 < f2
65 | }
66 | 


--------------------------------------------------------------------------------
/idtree/train.go:
--------------------------------------------------------------------------------
 1 | package idtree
 2 | 
 3 | import (
 4 | 	"math"
 5 | 	"runtime"
 6 | 	"sort"
 7 | 
 8 | 	"github.com/unixpickle/whichlang/tokens"
 9 | )
10 | 
11 | type splitInfo struct {
12 | 	TokenIdx  int
13 | 	Threshold float64
14 | 	Entropy   float64
15 | }
16 | 
17 | // Train returns a *Classifier which is the
18 | // result of running ID3 on a set of training
19 | // samples.
20 | func Train(freqs map[string][]tokens.Freqs) *Classifier {
21 | 	toks := allTokens(freqs)
22 | 	samples := freqsToLinearSamples(toks, freqs)
23 | 	return generateClassifier(toks, samples)
24 | }
25 | 
26 | func allTokens(freqs map[string][]tokens.Freqs) []string {
27 | 	words := make([]string, 0)
28 | 	seenWords := map[string]bool{}
29 | 	for _, freqsList := range freqs {
30 | 		for _, freqs := range freqsList {
31 | 			for word := range freqs {
32 | 				if !seenWords[word] {
33 | 					seenWords[word] = true
34 | 					words = append(words, word)
35 | 				}
36 | 			}
37 | 		}
38 | 	}
39 | 	return words
40 | }
41 | 
42 | // generateClassifier generates a classifier
43 | // for the given set of samples.
44 | func generateClassifier(toks []string, s []linearSample) *Classifier {
45 | 	tokIdx, thresh := bestDecision(s)
46 | 	if tokIdx == -1 {
47 | 		lang := languageMajority(s)
48 | 		return &Classifier{
49 | 			LeafClassification: &lang,
50 | 		}
51 | 	}
52 | 	res := &Classifier{
53 | 		Keyword:   toks[tokIdx],
54 | 		Threshold: thresh,
55 | 	}
56 | 	f, t := splitData(s, tokIdx, thresh)
57 | 	res.FalseBranch = generateClassifier(toks, f)
58 | 	res.TrueBranch = generateClassifier(toks, t)
59 | 	return res
60 | }
61 | 
62 | func splitData(s []linearSample, tokIdx int, thresh float64) (f, t []linearSample) {
63 | 	f = make([]linearSample, 0, len(s))
64 | 	t = make([]linearSample, 0, len(s))
65 | 
66 | 	for _, sample := range s {
67 | 		if sample.freqs[tokIdx] > thresh {
68 | 			t = append(t, sample)
69 | 		} else {
70 | 			f = append(f, sample)
71 | 		}
72 | 	}
73 | 
74 | 	return
75 | }
76 | 
77 | // bestDecision returns the token and threshold
78 | // which split the samples optimally (by the
79 | // criterion of entropy).
80 | // If no split exists, this returns (-1, -1).
81 | func bestDecision(s []linearSample) (tokIdx int, thresh float64) {
82 | 	if len(s) == 0 {
83 | 		return -1, -1
84 | 	}
85 | 
86 | 	maxProcs := runtime.GOMAXPROCS(0)
87 | 	tokenCount := len(s[0].freqs)
88 | 
89 | 	toksPerGo := tokenCount / maxProcs
90 | 	splitChan := make(chan *splitInfo, maxProcs)
91 | 	for i := 0; i < maxProcs; i++ {
92 | 		tokCount := toksPerGo
93 | 		tokStart := toksPerGo * i
94 | 
95 | 		// The last set might need to be slightly larger
96 | 		// due to division truncation.
 97 | 		if i == maxProcs-1 {
 98 | 			tokCount = tokenCount - tokStart
 99 | 		}
100 | 
101 | 		go bestNodeSubset(tokStart, tokCount, s, splitChan)
102 | 	}
103 | 
104 | 	var best *splitInfo
105 | 	for i := 0; i < maxProcs; i++ {
106 | 		res := <-splitChan
107 | 		if res == nil {
108 | 			continue
109 | 		}
110 | 		if best == nil || res.Entropy < best.Entropy {
111 | 			best = res
112 | 		}
113 | 	}
114 | 
115 | 	if best == nil {
116 | 		return -1, -1
117 | 	}
118 | 
119 | 	return best.TokenIdx, best.Threshold
120 | }
121 | 
122 | func bestNodeSubset(startIdx, count int, s []linearSample, res chan<- *splitInfo) {
123 | 	bestThresh := -1.0
124 | 	var bestEntropy float64
125 | 	var bestIdx int
126 | 	for i := 0; i < count; i++ {
127 | 		idx := startIdx + i
128 | 		thresh, entropy := bestSplit(s, idx)
129 | 		if thresh < 0 {
130 | 			continue
131 | 		} else if bestThresh < 0 || entropy < bestEntropy {
132 | 			bestEntropy = entropy
133 | 			bestThresh = thresh
134 | 			bestIdx = idx
135 | 		}
136 | 	}
137 | 	if bestThresh == -1 {
138 | 		res <- nil
139 | 	} else {
140 | 		res <- &splitInfo{bestIdx, bestThresh, bestEntropy}
141 | 	}
142 | }
143 | 
144 | // bestSplit finds the ideal threshold for splitting
145 | // samples by a given token (specified by an index).
146 | // This returns the threshold and the resulting entropy.
147 | // The threshold will be -1 if no split is useful.
148 | func bestSplit(unsorted []linearSample, tokenIdx int) (thresh float64, entrop float64) {
149 | 	samples := make([]linearSample, len(unsorted))
150 | 	copy(samples, unsorted)
151 | 	sorter := &sampleSorter{samples, tokenIdx}
152 | 	sort.Sort(sorter)
153 | 
154 | 	lowerDistribution := map[string]int{}
155 | 	upperDistribution := map[string]int{}
156 | 
157 | 	for _, sample := range samples {
158 | 		upperDistribution[sample.lang]++
159 | 	}
160 | 
161 | 	if len(upperDistribution) == 1 {
162 | 		// Can't split homogeneous data effectively.
163 | 		return -1, -1
164 | 	}
165 | 
166 | 	thresh = -1
167 | 	entrop = -1
168 | 
169 | 	if len(samples) == 0 {
170 | 		return
171 | 	}
172 | 
173 | 	lastFreq := samples[0].freqs[tokenIdx]
174 | 	for i := 1; i < len(samples); i++ {
175 | 		upperDistribution[samples[i-1].lang]--
176 | 		lowerDistribution[samples[i-1].lang]++
177 | 
178 | 		freq := samples[i].freqs[tokenIdx]
179 | 		if freq == lastFreq {
180 | 			continue
181 | 		}
182 | 
183 | 		upperFrac := float64(len(samples)-i) / float64(len(samples))
184 | 		lowerFrac := float64(i) / float64(len(samples))
185 | 		disorder := upperFrac*distributionEntropy(upperDistribution) +
186 | 			lowerFrac*distributionEntropy(lowerDistribution)
187 | 		if disorder < entrop || thresh == -1 {
188 | 			entrop = disorder
189 | 			thresh = (lastFreq + freq) / 2
190 | 		}
191 | 
192 | 		lastFreq = freq
193 | 	}
194 | 
195 | 	return
196 | }
197 | 
198 | func distributionEntropy(dist map[string]int) float64 {
199 | 	var res float64
200 | 	var totalCount int
201 | 	for _, count := range dist {
202 | 		totalCount += count
203 | 	}
204 | 	for _, count := range dist {
205 | 		fraction := float64(count) / float64(totalCount)
206 | 		if fraction != 0 {
207 | 			res -= math.Log(fraction) * fraction
208 | 		}
209 | 	}
210 | 	return res
211 | }
212 | 


--------------------------------------------------------------------------------
/knn/classifier.go:
--------------------------------------------------------------------------------
  1 | package knn
  2 | 
  3 | import (
  4 | 	"encoding/json"
  5 | 	"math"
  6 | 
  7 | 	"github.com/unixpickle/num-analysis/linalg"
  8 | 	"github.com/unixpickle/whichlang/tokens"
  9 | )
 10 | 
 11 | type Sample struct {
 12 | 	Language string
 13 | 	Vector   linalg.Vector
 14 | }
 15 | 
 16 | type Classifier struct {
 17 | 	Tokens  []string
 18 | 	Samples []Sample
 19 | 
 20 | 	NeighborCount int
 21 | }
 22 | 
 23 | func DecodeClassifier(d []byte) (*Classifier, error) {
 24 | 	var res Classifier
 25 | 	if err := json.Unmarshal(d, &res); err != nil {
 26 | 		return nil, err
 27 | 	}
 28 | 	return &res, nil
 29 | }
 30 | 
 31 | func (c *Classifier) Classify(f tokens.Freqs) string {
 32 | 	vec := make(linalg.Vector, len(c.Tokens))
 33 | 	for i, keyw := range c.Tokens {
 34 | 		vec[i] = f[keyw]
 35 | 	}
 36 | 
 37 | 	vecMag := vec.Dot(vec)
 38 | 	if vecMag == 0 {
 39 | 		return c.Samples[0].Language
 40 | 	}
 41 | 	vec.Scale(1 / math.Sqrt(vecMag))
 42 | 
 43 | 	return c.classifyVector(vec)
 44 | }
 45 | 
 46 | func (c *Classifier) Encode() []byte {
 47 | 	data, _ := json.Marshal(c)
 48 | 	return data
 49 | }
 50 | 
 51 | func (c *Classifier) Languages() []string {
 52 | 	seenLangs := map[string]bool{}
 53 | 	for _, sample := range c.Samples {
 54 | 		seenLangs[sample.Language] = true
 55 | 	}
 56 | 	res := make([]string, 0, len(seenLangs))
 57 | 	for lang := range seenLangs {
 58 | 		res = append(res, lang)
 59 | 	}
 60 | 	return res
 61 | }
 62 | 
 63 | func (c *Classifier) classifyVector(vec linalg.Vector) string {
 64 | 	matches := make([]match, 0, c.NeighborCount)
 65 | 	for _, sample := range c.Samples {
 66 | 		correlation := sample.Vector.Dot(vec)
 67 | 		insertIdx := matchInsertionIndex(matches, correlation)
 68 | 		if insertIdx >= c.NeighborCount {
 69 | 			continue
 70 | 		}
 71 | 		if len(matches) < c.NeighborCount {
 72 | 			matches = append(matches, match{})
 73 | 		}
 74 | 		copy(matches[insertIdx+1:], matches[insertIdx:])
 75 | 		matches[insertIdx] = match{
 76 | 			Language:    sample.Language,
 77 | 			Correlation: correlation,
 78 | 		}
 79 | 	}
 80 | 
 81 | 	return dominantClassification(matches)
 82 | }
 83 | 
 84 | func dominantClassification(matches []match) string {
 85 | 	scores := map[string]float64{}
 86 | 	for _, m := range matches {
 87 | 		scores[m.Language] += m.Correlation
 88 | 	}
 89 | 
 90 | 	var bestLang string
 91 | 	bestScore := math.Inf(-1)
 92 | 	for lang, score := range scores {
 93 | 		if score > bestScore {
 94 | 			bestScore = score
 95 | 			bestLang = lang
 96 | 		}
 97 | 	}
 98 | 
 99 | 	return bestLang
100 | }
101 | 
102 | type match struct {
103 | 	Language    string
104 | 	Correlation float64
105 | }
106 | 
107 | func matchInsertionIndex(m []match, corr float64) int {
108 | 	for i, x := range m {
109 | 		if x.Correlation < corr {
110 | 			return i
111 | 		}
112 | 	}
113 | 	return len(m)
114 | }
115 | 


--------------------------------------------------------------------------------
/knn/trainer.go:
--------------------------------------------------------------------------------
  1 | package knn
  2 | 
  3 | import (
  4 | 	"math"
  5 | 	"math/rand"
  6 | 	"sort"
  7 | 
  8 | 	"github.com/unixpickle/num-analysis/linalg"
  9 | 	"github.com/unixpickle/whichlang/tokens"
 10 | )
 11 | 
 12 | // crossValidationFrac specifies the fraction of
 13 | // samples which are used for cross-validation
 14 | // when determining the optimal k-value.
 15 | const crossValidationFrac = 0.3
 16 | 
 17 | func Train(f map[string][]tokens.Freqs) *Classifier {
 18 | 	seenToks := map[string]bool{}
 19 | 	sampleCount := 0
 20 | 	for _, samples := range f {
 21 | 		for _, sample := range samples {
 22 | 			for tok := range sample {
 23 | 				seenToks[tok] = true
 24 | 			}
 25 | 			sampleCount++
 26 | 		}
 27 | 	}
 28 | 
 29 | 	toks := make([]string, 0, len(seenToks))
 30 | 	for tok := range seenToks {
 31 | 		toks = append(toks, tok)
 32 | 	}
 33 | 
 34 | 	samples := make([]Sample, 0, sampleCount)
 35 | 	for lang, freqSamples := range f {
 36 | 		for _, freqs := range freqSamples {
 37 | 			vec := make(linalg.Vector, len(toks))
 38 | 			for i, token := range toks {
 39 | 				vec[i] = freqs[token]
 40 | 			}
 41 | 			if mag := vec.Dot(vec); mag != 0 {
 42 | 				// Normalize to unit length, matching Classify.
 43 | 				vec.Scale(1 / math.Sqrt(mag))
 44 | 			}
 45 | 			samples = append(samples, Sample{
 46 | 				Language: lang,
 47 | 				Vector:   vec,
 48 | 			})
 49 | 		}
 50 | 	}
 51 | 
 52 | 	kValue := optimalKValue(samples)
 53 | 	return &Classifier{
 54 | 		Tokens:        toks,
 55 | 		Samples:       samples,
 56 | 		NeighborCount: kValue,
 57 | 	}
 58 | }
 59 | 
 60 | func optimalKValue(s []Sample) int {
 61 | 	crossCount := int(crossValidationFrac * float64(len(s)))
 62 | 	if crossCount == 0 {
 63 | 		return 1
 64 | 	}
 65 | 	samples := shuffleSamples(s)
 66 | 	crossSamples := samples[0:crossCount]
 67 | 	trainingSamples := samples[crossCount:]
 68 | 
 69 | 	crossMatches := sortedCrossMatches(crossSamples, trainingSamples)
 70 | 
 71 | 	bestK := 1
 72 | 	bestCorrect := 0
 73 | 	for k := 1; k <= len(trainingSamples); k++ {
 74 | 		crossCorrect := 0
 75 | 		for crossIdx, matches := range crossMatches {
 76 | 			classification := dominantClassification(matches[:k])
 77 | 			actualLang := crossSamples[crossIdx].Language
 78 | 			if classification == actualLang {
 79 | 				crossCorrect++
 80 | 			}
 81 | 		}
 82 | 		if crossCorrect > bestCorrect {
 83 | 			bestK = k
 84 | 			bestCorrect = crossCorrect
 85 | 		}
 86 | 	}
 87 | 
 88 | 	return bestK
 89 | }
 90 | 
 91 | func sortedCrossMatches(cross, training []Sample) [][]match {
 92 | 	res := make([][]match, len(cross))
 93 | 	for i, crossSample := range cross {
 94 | 		res[i] = make([]match, len(training))
 95 | 		for j, trainingSample := range training {
 96 | 			correlation := trainingSample.Vector.Dot(crossSample.Vector)
 97 | 			res[i][j] = match{
 98 | 				Language:    trainingSample.Language,
 99 | 				Correlation: correlation,
100 | 			}
101 | 		}
102 | 		sort.Sort(matchSorter(res[i]))
103 | 	}
104 | 	return res
105 | }
106 | 
107 | func shuffleSamples(s []Sample) []Sample {
108 | 	res := make([]Sample, len(s))
109 | 
110 | 	p := rand.Perm(len(s))
111 | 	for i, x := range p {
112 | 		res[i] = s[x]
113 | 	}
114 | 
115 | 	return res
116 | }
117 | 
118 | type matchSorter []match
119 | 
120 | func (m matchSorter) Len() int {
121 | 	return len(m)
122 | }
123 | 
124 | func (m matchSorter) Less(i, j int) bool {
125 | 	return m[i].Correlation > m[j].Correlation
126 | }
127 | 
128 | func (m matchSorter) Swap(i, j int) {
129 | 	m[i], m[j] = m[j], m[i]
130 | }
131 | 


--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
 1 | // Package whichlang is a suite of Machine Learning
 2 | // tools to classify programming languages.
 3 | package whichlang
 4 | 
 5 | import (
 6 | 	"github.com/unixpickle/whichlang/gaussbayes"
 7 | 	"github.com/unixpickle/whichlang/idtree"
 8 | 	"github.com/unixpickle/whichlang/knn"
 9 | 	"github.com/unixpickle/whichlang/neuralnet"
10 | 	"github.com/unixpickle/whichlang/svm"
11 | 	"github.com/unixpickle/whichlang/tokens"
12 | )
13 | 
14 | type Classifier interface {
15 | 	// Classify classifies a tokenized source file.
16 | 	//
17 | 	// The returned string is the name of the
18 | 	// programming language in which the file is
19 | 	// most likely written.
20 | 	Classify(tokens.Freqs) string
21 | 
22 | 	// Languages returns all possible languages
23 | 	// that Classify() might return.
24 | 	// The result is not sorted, and its order
25 | 	// may change across calls.
26 | 	// Callers should not modify the returned slice.
27 | 	Languages() []string
28 | 
29 | 	// Encode serializes this classifier as binary
30 | 	// data.
31 | 	Encode() []byte
32 | }
33 | 
34 | // A Trainer generates a Classifier using
35 | // a collection of tokenized sample files.
36 | type Trainer func(map[string][]tokens.Freqs) Classifier
37 | 
38 | // A Decoder decodes a certain type of
39 | // Classifier from binary data.
40 | type Decoder func(d []byte) (Classifier, error)
41 | 
42 | // ClassifierNames is an array containing the
43 | // names of every supported classifier.
44 | var ClassifierNames = []string{"idtree", "neuralnet", "knn", "svm", "gaussbayes"}
45 | 
46 | // Trainers maps classifier names to their
47 | // corresponding Trainers.
48 | var Trainers = map[string]Trainer{ 49 | "idtree": func(freqs map[string][]tokens.Freqs) Classifier { 50 | return idtree.Train(freqs) 51 | }, 52 | "neuralnet": func(freqs map[string][]tokens.Freqs) Classifier { 53 | return neuralnet.Train(freqs) 54 | }, 55 | "knn": func(freqs map[string][]tokens.Freqs) Classifier { 56 | return knn.Train(freqs) 57 | }, 58 | "svm": func(freqs map[string][]tokens.Freqs) Classifier { 59 | return svm.Train(freqs) 60 | }, 61 | "gaussbayes": func(freqs map[string][]tokens.Freqs) Classifier { 62 | return gaussbayes.Train(freqs) 63 | }, 64 | } 65 | 66 | // Decoders maps classifier names to their 67 | // corresponding Decoders. 68 | var Decoders = map[string]Decoder{ 69 | "idtree": func(d []byte) (Classifier, error) { 70 | return idtree.DecodeClassifier(d) 71 | }, 72 | "neuralnet": func(d []byte) (Classifier, error) { 73 | return neuralnet.DecodeNetwork(d) 74 | }, 75 | "knn": func(d []byte) (Classifier, error) { 76 | return knn.DecodeClassifier(d) 77 | }, 78 | "svm": func(d []byte) (Classifier, error) { 79 | return svm.DecodeClassifier(d) 80 | }, 81 | "gaussbayes": func(d []byte) (Classifier, error) { 82 | return gaussbayes.DecodeClassifier(d) 83 | }, 84 | } 85 | 86 | // Descriptions maps classifier names to 87 | // one-line descriptions of the classifier. 
88 | var Descriptions = map[string]string{ 89 | "idtree": "decision trees generated with ID3", 90 | "neuralnet": "feedforward neural network", 91 | "knn": "K-nearest neighbors", 92 | "svm": "support vector machines", 93 | "gaussbayes": "naive Bayes with Gaussians", 94 | } 95 | -------------------------------------------------------------------------------- /neuralnet/classifier.go: -------------------------------------------------------------------------------- 1 | package neuralnet 2 | 3 | import ( 4 | "encoding/json" 5 | "math" 6 | 7 | "github.com/unixpickle/num-analysis/kahan" 8 | "github.com/unixpickle/whichlang/tokens" 9 | ) 10 | 11 | // A Network is a feedforward neural network with 12 | // a single hidden layer. 13 | type Network struct { 14 | Tokens []string 15 | Langs []string 16 | 17 | // In the following weights, the last weight for 18 | // each neuron corresponds to a constant shift, 19 | // and is not multiplied by an input's value. 20 | HiddenWeights [][]float64 21 | OutputWeights [][]float64 22 | 23 | // Parameters used to center the network's 24 | // input frequencies around 0 and scale them 25 | // to a standard deviation of 1.
26 | InputShift float64 27 | InputScale float64 28 | } 29 | 30 | func DecodeNetwork(data []byte) (*Network, error) { 31 | var n Network 32 | if err := json.Unmarshal(data, &n); err != nil { 33 | return nil, err 34 | } 35 | return &n, nil 36 | } 37 | 38 | func (n *Network) Copy() *Network { 39 | res := &Network{ 40 | Tokens: make([]string, len(n.Tokens)), 41 | Langs: make([]string, len(n.Langs)), 42 | HiddenWeights: make([][]float64, len(n.HiddenWeights)), 43 | OutputWeights: make([][]float64, len(n.OutputWeights)), 44 | InputShift: n.InputShift, 45 | InputScale: n.InputScale, 46 | } 47 | copy(res.Tokens, n.Tokens) 48 | copy(res.Langs, n.Langs) 49 | for i, w := range n.HiddenWeights { 50 | res.HiddenWeights[i] = make([]float64, len(w)) 51 | copy(res.HiddenWeights[i], w) 52 | } 53 | for i, w := range n.OutputWeights { 54 | res.OutputWeights[i] = make([]float64, len(w)) 55 | copy(res.OutputWeights[i], w) 56 | } 57 | return res 58 | } 59 | 60 | func (n *Network) Classify(freqs tokens.Freqs) string { 61 | inputs := n.shiftedInput(freqs) 62 | 63 | outputSums := make([]*kahan.Summer64, len(n.OutputWeights)) 64 | for i := range outputSums { 65 | outputSums[i] = kahan.NewSummer64() 66 | outputSums[i].Add(n.outputBias(i)) 67 | } 68 | 69 | for hiddenIndex, hiddenWeights := range n.HiddenWeights { 70 | hiddenSum := kahan.NewSummer64() 71 | for j, input := range inputs { 72 | hiddenSum.Add(input * hiddenWeights[j]) 73 | } 74 | hiddenSum.Add(n.hiddenBias(hiddenIndex)) 75 | 76 | hiddenOut := sigmoid(hiddenSum.Sum()) 77 | for j, outSum := range outputSums { 78 | weight := n.OutputWeights[j][hiddenIndex] 79 | outSum.Add(weight * hiddenOut) 80 | } 81 | } 82 | 83 | maxSum := outputSums[0].Sum() 84 | maxIdx := 0 85 | for i, x := range outputSums { 86 | if x.Sum() > maxSum { 87 | maxSum = x.Sum() 88 | maxIdx = i 89 | } 90 | } 91 | return n.Langs[maxIdx] 92 | } 93 | 94 | func (n *Network) Encode() []byte { 95 | enc, _ := json.Marshal(n) 96 | return enc 97 | } 98 | 99 | func (n *Network) 
Languages() []string { 100 | return n.Langs 101 | } 102 | 103 | func (n *Network) outputBias(outputIdx int) float64 { 104 | return n.OutputWeights[outputIdx][len(n.HiddenWeights)] 105 | } 106 | 107 | func (n *Network) hiddenBias(hiddenIdx int) float64 { 108 | return n.HiddenWeights[hiddenIdx][len(n.Tokens)] 109 | } 110 | 111 | func (n *Network) containsNaN() bool { 112 | for _, wss := range [][][]float64{n.HiddenWeights, n.OutputWeights} { 113 | for _, ws := range wss { 114 | for _, w := range ws { 115 | if math.IsNaN(w) { 116 | return true 117 | } 118 | } 119 | } 120 | } 121 | return false 122 | } 123 | 124 | func (n *Network) shiftedInput(f tokens.Freqs) []float64 { 125 | res := make([]float64, len(n.Tokens)) 126 | for i, word := range n.Tokens { 127 | res[i] = (f[word] + n.InputShift) * n.InputScale 128 | } 129 | return res 130 | } 131 | 132 | func sigmoid(x float64) float64 { 133 | return 1.0 / (1.0 + math.Exp(-x)) 134 | } 135 | -------------------------------------------------------------------------------- /neuralnet/data_set.go: -------------------------------------------------------------------------------- 1 | package neuralnet 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "sort" 7 | 8 | "github.com/unixpickle/num-analysis/kahan" 9 | "github.com/unixpickle/whichlang/tokens" 10 | ) 11 | 12 | const ValidationFraction = 0.3 13 | 14 | // A DataSet is a set of data split into training 15 | // samples and validation samples. 16 | type DataSet struct { 17 | ValidationSamples map[string][]tokens.Freqs 18 | TrainingSamples map[string][]tokens.Freqs 19 | NormalTrainingSamples map[string][][]float64 20 | 21 | // These are statistical properties of the 22 | // training samples' frequency values. 23 | MeanFrequency float64 24 | FrequencyStddev float64 25 | } 26 | 27 | // NewDataSet creates a DataSet by randomly 28 | // partitioning some data samples into 29 | // validation and training samples. 
30 | func NewDataSet(samples map[string][]tokens.Freqs) *DataSet { 31 | res := &DataSet{ 32 | ValidationSamples: map[string][]tokens.Freqs{}, 33 | TrainingSamples: map[string][]tokens.Freqs{}, 34 | } 35 | for lang, langSamples := range samples { 36 | shuffled := make([]tokens.Freqs, len(langSamples)) 37 | perm := rand.Perm(len(shuffled)) 38 | for i, x := range perm { 39 | shuffled[i] = langSamples[x] 40 | } 41 | 42 | numValid := int(float64(len(langSamples)) * ValidationFraction) 43 | res.ValidationSamples[lang] = shuffled[:numValid] 44 | res.TrainingSamples[lang] = shuffled[numValid:] 45 | } 46 | 47 | res.computeStatistics() 48 | res.computeNormalSamples() 49 | 50 | return res 51 | } 52 | 53 | // CrossScore returns the fraction of withheld 54 | // samples the Network classifies correctly. 55 | func (c *DataSet) CrossScore(n *Network) float64 { 56 | return scoreNetwork(n, c.ValidationSamples) 57 | } 58 | 59 | // TrainingScore returns the fraction of 60 | // training samples the Network classifies correctly. 61 | func (c *DataSet) TrainingScore(n *Network) float64 { 62 | return scoreNetwork(n, c.TrainingSamples) 63 | } 64 | 65 | // Tokens returns all of the tokens from all 66 | // of the training samples. 67 | func (c *DataSet) Tokens() []string { 68 | toks := map[string]bool{} 69 | for _, samples := range c.TrainingSamples { 70 | for _, sample := range samples { 71 | for tok := range sample { 72 | toks[tok] = true 73 | } 74 | } 75 | } 76 | 77 | res := make([]string, 0, len(toks)) 78 | for tok := range toks { 79 | res = append(res, tok) 80 | } 81 | sort.Strings(res) 82 | return res 83 | } 84 | 85 | // Langs returns all of the languages represented 86 | // by the training samples.
87 | func (c *DataSet) Langs() []string { 88 | res := make([]string, 0, len(c.TrainingSamples)) 89 | for lang := range c.TrainingSamples { 90 | res = append(res, lang) 91 | } 92 | sort.Strings(res) 93 | return res 94 | } 95 | 96 | func (c *DataSet) computeStatistics() { 97 | tokens := c.Tokens() 98 | 99 | freqSum := kahan.NewSummer64() 100 | freqCount := 0 101 | for _, langSamples := range c.TrainingSamples { 102 | for _, sample := range langSamples { 103 | freqCount += len(tokens) 104 | for _, freq := range sample { 105 | freqSum.Add(freq) 106 | } 107 | } 108 | } 109 | 110 | c.MeanFrequency = freqSum.Sum() / float64(freqCount) 111 | 112 | variationSum := kahan.NewSummer64() 113 | for _, langSamples := range c.TrainingSamples { 114 | for _, sample := range langSamples { 115 | for _, token := range tokens { 116 | freq := sample[token] 117 | variationSum.Add(math.Pow(freq-c.MeanFrequency, 2)) 118 | } 119 | } 120 | } 121 | 122 | c.FrequencyStddev = math.Sqrt(variationSum.Sum() / float64(freqCount)) 123 | } 124 | 125 | func (c *DataSet) computeNormalSamples() { 126 | c.NormalTrainingSamples = map[string][][]float64{} 127 | tokens := c.Tokens() 128 | 129 | for lang, langSamples := range c.TrainingSamples { 130 | sampleList := make([][]float64, len(langSamples)) 131 | for i, sample := range langSamples { 132 | sampleVec := make([]float64, len(tokens)) 133 | for j, token := range tokens { 134 | sampleVec[j] = (sample[token] - c.MeanFrequency) / c.FrequencyStddev 135 | } 136 | sampleList[i] = sampleVec 137 | } 138 | c.NormalTrainingSamples[lang] = sampleList 139 | } 140 | } 141 | 142 | func scoreNetwork(n *Network, samples map[string][]tokens.Freqs) float64 { 143 | var totalRight int 144 | var total int 145 | for lang, langSamples := range samples { 146 | for _, sample := range langSamples { 147 | if n.Classify(sample) == lang { 148 | totalRight++ 149 | } 150 | total++ 151 | } 152 | } 153 | return float64(totalRight) / float64(total) 154 | } 155 | 
-------------------------------------------------------------------------------- /neuralnet/env.go: -------------------------------------------------------------------------------- 1 | package neuralnet 2 | 3 | import ( 4 | "math" 5 | "os" 6 | "strconv" 7 | ) 8 | 9 | const DefaultMaxIterations = 6400 10 | 11 | // DefaultHiddenLayerScale specifies how much 12 | // larger the hidden layer is than the output 13 | // layer, by default. 14 | const DefaultHiddenLayerScale = 2.0 15 | 16 | // VerboseEnvVar is an environment variable 17 | // which can be set to "1" to make the 18 | // neuralnet print out status reports. 19 | var VerboseEnvVar = "NEURALNET_VERBOSE" 20 | 21 | // VerboseStepsEnvVar is an environment 22 | // variable which can be set to "1" to make 23 | // neuralnet print out status reports after 24 | // each iteration of gradient descent. 25 | var VerboseStepsEnvVar = "NEURALNET_VERBOSE_STEPS" 26 | 27 | // StepSizeEnvVar is an environment variable 28 | // which can be used to specify the step size 29 | // for use in gradient descent. 30 | var StepSizeEnvVar = "NEURALNET_STEP_SIZE" 31 | 32 | // MaxItersEnvVar is an environment variable 33 | // specifying the maximum number of iterations 34 | // of gradient descent to perform. 35 | var MaxItersEnvVar = "NEURALNET_MAX_ITERS" 36 | 37 | // HiddenSizeEnvVar is an environment variable 38 | // specifying the number of hidden neurons. 
39 | var HiddenSizeEnvVar = "NEURALNET_HIDDEN_SIZE" 40 | 41 | func verboseFlag() bool { 42 | return os.Getenv(VerboseEnvVar) == "1" 43 | } 44 | 45 | func verboseStepsFlag() bool { 46 | return os.Getenv(VerboseStepsEnvVar) == "1" 47 | } 48 | 49 | func stepSizes() []float64 { 50 | if stepSize := os.Getenv(StepSizeEnvVar); stepSize == "" { 51 | var res []float64 52 | for power := -20; power < 10; power++ { 53 | res = append(res, math.Pow(2, float64(power))) 54 | } 55 | return res 56 | } else { 57 | val, err := strconv.ParseFloat(stepSize, 64) 58 | if err != nil { 59 | panic(err) 60 | } 61 | return []float64{val} 62 | } 63 | } 64 | 65 | func maxIterations() int { 66 | if max := os.Getenv(MaxItersEnvVar); max == "" { 67 | return DefaultMaxIterations 68 | } else { 69 | val, err := strconv.Atoi(max) 70 | if err != nil { 71 | panic(err) 72 | } 73 | return val 74 | } 75 | } 76 | 77 | func hiddenSize(outputCount int) int { 78 | if size := os.Getenv(HiddenSizeEnvVar); size == "" { 79 | return int(float64(outputCount)*DefaultHiddenLayerScale + 0.5) 80 | } else { 81 | val, err := strconv.Atoi(size) 82 | if err != nil { 83 | panic(err) 84 | } 85 | return val 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /neuralnet/gradients.go: -------------------------------------------------------------------------------- 1 | package neuralnet 2 | 3 | import "github.com/unixpickle/num-analysis/kahan" 4 | 5 | // A gradientCalc can compute gradients of the 6 | // error function 0.5*||Actual - Expected||^2 7 | // for a neural network on a given input. 8 | // 9 | // A gradientCalc demands a lot of scratch 10 | // memory, so it is a good idea to create one 11 | // gradientCalc and then reuse it over and over. 
12 | type gradientCalc struct { 13 | n *Network 14 | 15 | hiddenOutputs []float64 16 | outputs []float64 17 | inputs []float64 18 | expectedOut []float64 19 | 20 | hiddenOutPartials []float64 21 | 22 | OutputPartials [][]float64 23 | HiddenPartials [][]float64 24 | } 25 | 26 | func newGradientCalc(n *Network) *gradientCalc { 27 | res := &gradientCalc{ 28 | n: n, 29 | hiddenOutputs: make([]float64, len(n.HiddenWeights)), 30 | outputs: make([]float64, len(n.OutputWeights)), 31 | expectedOut: make([]float64, len(n.OutputWeights)), 32 | hiddenOutPartials: make([]float64, len(n.HiddenWeights)), 33 | OutputPartials: make([][]float64, len(n.OutputWeights)), 34 | HiddenPartials: make([][]float64, len(n.HiddenWeights)), 35 | } 36 | 37 | for i := range res.OutputPartials { 38 | res.OutputPartials[i] = make([]float64, len(res.hiddenOutputs)+1) 39 | } 40 | for i := range res.HiddenPartials { 41 | res.HiddenPartials[i] = make([]float64, len(n.Tokens)+1) 42 | } 43 | 44 | return res 45 | } 46 | 47 | func (g *gradientCalc) Compute(inputs []float64, langIdx int) { 48 | g.inputs = inputs 49 | for j := range g.expectedOut { 50 | if j == langIdx { 51 | g.expectedOut[j] = 1 52 | } else { 53 | g.expectedOut[j] = 0 54 | } 55 | } 56 | 57 | g.computeOutputs() 58 | g.computeGradients() 59 | } 60 | 61 | func (g *gradientCalc) computeOutputs() { 62 | outputSums := make([]*kahan.Summer64, len(g.outputs)) 63 | 64 | for i := range outputSums { 65 | outputSums[i] = kahan.NewSummer64() 66 | outputSums[i].Add(g.n.outputBias(i)) 67 | } 68 | 69 | for hiddenIndex, hiddenWeights := range g.n.HiddenWeights { 70 | hiddenSum := kahan.NewSummer64() 71 | for j, input := range g.inputs { 72 | hiddenSum.Add(input * hiddenWeights[j]) 73 | } 74 | hiddenSum.Add(g.n.hiddenBias(hiddenIndex)) 75 | 76 | hiddenOut := sigmoid(hiddenSum.Sum()) 77 | g.hiddenOutputs[hiddenIndex] = hiddenOut 78 | for j, outSum := range outputSums { 79 | weight := g.n.OutputWeights[j][hiddenIndex] 80 | outSum.Add(weight * hiddenOut) 81 | } 
82 | } 83 | 84 | for i, sum := range outputSums { 85 | g.outputs[i] = sigmoid(sum.Sum()) 86 | } 87 | } 88 | 89 | func (g *gradientCalc) computeGradients() { 90 | for i, output := range g.outputs { 91 | gradient := g.OutputPartials[i] 92 | diff := output - g.expectedOut[i] 93 | sumPartial := (1 - output) * output * diff 94 | for j, input := range g.hiddenOutputs { 95 | gradient[j] = input * sumPartial 96 | g.hiddenOutPartials[j] = g.n.OutputWeights[i][j] * sumPartial 97 | } 98 | gradient[len(g.hiddenOutputs)] = sumPartial 99 | } 100 | for i, output := range g.hiddenOutputs { 101 | gradient := g.HiddenPartials[i] 102 | sumPartial := (1 - output) * output * g.hiddenOutPartials[i] 103 | for j, input := range g.inputs { 104 | gradient[j] = input * sumPartial 105 | } 106 | gradient[len(g.inputs)] = sumPartial 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /neuralnet/train.go: -------------------------------------------------------------------------------- 1 | package neuralnet 2 | 3 | import ( 4 | "log" 5 | "math/rand" 6 | 7 | "github.com/unixpickle/whichlang/tokens" 8 | ) 9 | 10 | const InitialIterationCount = 200 11 | 12 | func Train(data map[string][]tokens.Freqs) *Network { 13 | ds := NewDataSet(data) 14 | 15 | var best *Network 16 | var bestCrossScore float64 17 | var bestTrainScore float64 18 | 19 | verbose := verboseFlag() 20 | 21 | for _, stepSize := range stepSizes() { 22 | if verbose { 23 | log.Printf("trying step size %f", stepSize) 24 | } 25 | 26 | t := NewTrainer(ds, stepSize, verbose) 27 | t.Train(maxIterations()) 28 | 29 | n := t.Network() 30 | if n.containsNaN() { 31 | if verbose { 32 | log.Printf("got NaN for step size %f", stepSize) 33 | } 34 | continue 35 | } 36 | crossScore := ds.CrossScore(n) 37 | trainScore := ds.TrainingScore(n) 38 | if verbose { 39 | log.Printf("stepSize=%f crossScore=%f trainScore=%f", stepSize, 40 | crossScore, trainScore) 41 | } 42 | if (crossScore == bestCrossScore && 
trainScore >= bestTrainScore) || 43 | best == nil || (crossScore > bestCrossScore) { 44 | bestCrossScore = crossScore 45 | bestTrainScore = trainScore 46 | best = n 47 | } 48 | } 49 | 50 | return best 51 | } 52 | 53 | type Trainer struct { 54 | n *Network 55 | d *DataSet 56 | g *gradientCalc 57 | 58 | stepSize float64 59 | verbose bool 60 | } 61 | 62 | func NewTrainer(d *DataSet, stepSize float64, verbose bool) *Trainer { 63 | hiddenCount := hiddenSize(len(d.TrainingSamples)) 64 | n := &Network{ 65 | Tokens: d.Tokens(), 66 | Langs: d.Langs(), 67 | HiddenWeights: make([][]float64, hiddenCount), 68 | OutputWeights: make([][]float64, len(d.TrainingSamples)), 69 | InputShift: -d.MeanFrequency, 70 | InputScale: 1 / d.FrequencyStddev, 71 | } 72 | for i := range n.OutputWeights { 73 | n.OutputWeights[i] = make([]float64, hiddenCount+1) 74 | for j := range n.OutputWeights[i] { 75 | n.OutputWeights[i][j] = rand.Float64()*2 - 1 76 | } 77 | } 78 | for i := range n.HiddenWeights { 79 | n.HiddenWeights[i] = make([]float64, len(n.Tokens)+1) 80 | for j := range n.HiddenWeights[i] { 81 | n.HiddenWeights[i][j] = rand.Float64()*2 - 1 82 | } 83 | } 84 | return &Trainer{ 85 | n: n, 86 | d: d, 87 | g: newGradientCalc(n), 88 | stepSize: stepSize, 89 | verbose: verbose, 90 | } 91 | } 92 | 93 | func (t *Trainer) Train(maxIters int) { 94 | iters := InitialIterationCount 95 | if iters > maxIters { 96 | iters = maxIters 97 | } 98 | for i := 0; i < iters; i++ { 99 | if verboseStepsFlag() { 100 | log.Printf("done %d iterations, cross=%f training=%f", 101 | i, t.d.CrossScore(t.n), t.d.TrainingScore(t.n)) 102 | } 103 | t.runAllSamples() 104 | } 105 | if iters == maxIters { 106 | return 107 | } 108 | 109 | if t.n.containsNaN() { 110 | return 111 | } 112 | 113 | // Use cross-validation to find the best 114 | // number of iterations. 
115 | crossScore := t.d.CrossScore(t.n) 116 | trainScore := t.d.TrainingScore(t.n) 117 | lastNet := t.n.Copy() 118 | 119 | for { 120 | if t.verbose { 121 | log.Printf("current scores: cross=%f train=%f iters=%d", 122 | crossScore, trainScore, iters) 123 | } 124 | 125 | nextAmount := iters 126 | if nextAmount+iters > maxIters { 127 | nextAmount = maxIters - iters 128 | } 129 | for i := 0; i < nextAmount; i++ { 130 | if verboseStepsFlag() { 131 | log.Printf("done %d iterations, cross=%f training=%f", 132 | i+iters, t.d.CrossScore(t.n), t.d.TrainingScore(t.n)) 133 | } 134 | t.runAllSamples() 135 | if t.n.containsNaN() { 136 | break 137 | } 138 | } 139 | iters += nextAmount 140 | 141 | if t.n.containsNaN() { 142 | t.n = lastNet 143 | break 144 | } 145 | 146 | newCrossScore := t.d.CrossScore(t.n) 147 | newTrainScore := t.d.TrainingScore(t.n) 148 | if (newCrossScore == crossScore && newTrainScore == trainScore) || 149 | newCrossScore < crossScore { 150 | t.n = lastNet 151 | return 152 | } 153 | 154 | crossScore = newCrossScore 155 | trainScore = newTrainScore 156 | 157 | if iters == maxIters { 158 | return 159 | } 160 | lastNet = t.n.Copy() 161 | } 162 | } 163 | 164 | func (t *Trainer) Network() *Network { 165 | return t.n 166 | } 167 | 168 | func (t *Trainer) runAllSamples() { 169 | var samples []struct { 170 | LangIdx int 171 | Sample []float64 172 | } 173 | 174 | for i, lang := range t.n.Langs { 175 | var sample struct { 176 | LangIdx int 177 | Sample []float64 178 | } 179 | sample.LangIdx = i 180 | 181 | trainingSamples := t.d.NormalTrainingSamples[lang] 182 | for _, s := range trainingSamples { 183 | sample.Sample = s 184 | samples = append(samples, sample) 185 | } 186 | } 187 | 188 | perm := rand.Perm(len(samples)) 189 | for _, i := range perm { 190 | t.descendSample(samples[i].Sample, samples[i].LangIdx) 191 | } 192 | } 193 | 194 | // descendSample performs gradient descent to 195 | // reduce the output error for a given sample. 
196 | func (t *Trainer) descendSample(inputs []float64, langIdx int) { 197 | t.g.Compute(inputs, langIdx) 198 | 199 | for i, partials := range t.g.HiddenPartials { 200 | for j, partial := range partials { 201 | t.n.HiddenWeights[i][j] -= partial * t.stepSize 202 | } 203 | } 204 | for i, partials := range t.g.OutputPartials { 205 | for j, partial := range partials { 206 | t.n.OutputWeights[i][j] -= partial * t.stepSize 207 | } 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /svm/classifier.go: -------------------------------------------------------------------------------- 1 | package svm 2 | 3 | import ( 4 | "encoding/json" 5 | "math" 6 | 7 | "github.com/unixpickle/num-analysis/kahan" 8 | "github.com/unixpickle/num-analysis/linalg" 9 | "github.com/unixpickle/whichlang/tokens" 10 | ) 11 | 12 | // BinaryClassifier stores info for the 13 | // binary classifiers used in a Classifier. 14 | type BinaryClassifier struct { 15 | // SupportVectors stores indices to 16 | // elements of Classifier.SampleVectors. 17 | SupportVectors []int 18 | 19 | // Weights are the corresponding weights 20 | // for each of the support vectors. 21 | Weights []float64 22 | 23 | Threshold float64 24 | } 25 | 26 | // Classifier uses one-against-all SVMs to 27 | // classify source files. 28 | type Classifier struct { 29 | Keywords []string 30 | Kernel *Kernel 31 | 32 | SampleVectors []linalg.Vector 33 | 34 | // Classifiers maps each language to its 35 | // corresponding one-against-all binary 36 | // classifier. 
37 | Classifiers map[string]BinaryClassifier 38 | } 39 | 40 | func DecodeClassifier(d []byte) (*Classifier, error) { 41 | var c Classifier 42 | if err := json.Unmarshal(d, &c); err != nil { 43 | return nil, err 44 | } 45 | return &c, nil 46 | } 47 | 48 | func (c *Classifier) Classify(sample tokens.Freqs) string { 49 | products := c.sampleProducts(sample) 50 | 51 | var bestLanguage string 52 | bestClassification := math.Inf(-1) 53 | 54 | for lang, classifier := range c.Classifiers { 55 | productSum := kahan.NewSummer64() 56 | for i, vecIdx := range classifier.SupportVectors { 57 | productSum.Add(products[vecIdx] * classifier.Weights[i]) 58 | } 59 | productSum.Add(-classifier.Threshold) 60 | if productSum.Sum() > bestClassification { 61 | bestClassification = productSum.Sum() 62 | bestLanguage = lang 63 | } 64 | } 65 | 66 | return bestLanguage 67 | } 68 | 69 | func (c *Classifier) Encode() []byte { 70 | res, _ := json.Marshal(c) 71 | return res 72 | } 73 | 74 | func (c *Classifier) Languages() []string { 75 | res := make([]string, 0, len(c.Classifiers)) 76 | for lang := range c.Classifiers { 77 | res = append(res, lang) 78 | } 79 | return res 80 | } 81 | 82 | func (c *Classifier) sampleProducts(sample tokens.Freqs) []float64 { 83 | vec := c.sampleVector(sample) 84 | res := make([]float64, len(c.SampleVectors)) 85 | for i, s := range c.SampleVectors { 86 | res[i] = c.Kernel.Product(s, vec) 87 | } 88 | return res 89 | } 90 | 91 | func (c *Classifier) sampleVector(sample tokens.Freqs) linalg.Vector { 92 | vec := make(linalg.Vector, len(c.Keywords)) 93 | for i, keyword := range c.Keywords { 94 | vec[i] = sample[keyword] 95 | } 96 | return vec 97 | } 98 | -------------------------------------------------------------------------------- /svm/kernel.go: -------------------------------------------------------------------------------- 1 | package svm 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "strconv" 7 | 8 | "github.com/unixpickle/num-analysis/linalg" 9 | ) 10 | 11 | type 
KernelType int 12 | 13 | const ( 14 | // LinearKernel generates a linear classifier 15 | // with no parameters. 16 | LinearKernel KernelType = iota 17 | 18 | // PolynomialKernel computes inner products as 19 | // (x*y + k1)^k2, where k1 and k2 are parameters. 20 | PolynomialKernel 21 | 22 | // RadialBasisKernel computes inner products as 23 | // exp(-k1*||x-y||^2), where k1 is a parameter. 24 | RadialBasisKernel 25 | ) 26 | 27 | // A Kernel computes inner products of vectors 28 | // after transforming them into some space. 29 | type Kernel struct { 30 | Type KernelType 31 | Params []float64 32 | } 33 | 34 | // Product returns the product of two vectors 35 | // under this kernel. 36 | func (k *Kernel) Product(v1, v2 linalg.Vector) float64 { 37 | switch k.Type { 38 | case LinearKernel: 39 | return v1.Dot(v2) 40 | case PolynomialKernel: 41 | if len(k.Params) != 2 { 42 | panic("expected two parameters for polynomial kernel") 43 | } 44 | return math.Pow(v1.Dot(v2)+k.Params[0], k.Params[1]) 45 | case RadialBasisKernel: 46 | if len(k.Params) != 1 { 47 | panic("expected one parameter for radial basis kernel") 48 | } 49 | diff := v1.Copy().Scale(-1).Add(v2) 50 | return math.Exp(-k.Params[0] * diff.Dot(diff)) 51 | default: 52 | panic("unknown kernel type: " + strconv.Itoa(int(k.Type))) 53 | } 54 | } 55 | 56 | // String returns a mathematical formula which 57 | // represents this kernel (e.g. "(x*y+1)^2"). 
58 | func (k *Kernel) String() string { 59 | switch k.Type { 60 | case LinearKernel: 61 | return "x*y" 62 | case PolynomialKernel: 63 | if len(k.Params) != 2 { 64 | panic("expected two parameters for polynomial kernel") 65 | } 66 | return fmt.Sprintf("(x*y + %f)^%f", k.Params[0], k.Params[1]) 67 | case RadialBasisKernel: 68 | if len(k.Params) != 1 { 69 | panic("expected one parameter for radial basis kernel") 70 | } 71 | return fmt.Sprintf("exp(-%f*||x-y||^2)", k.Params[0]) 72 | default: 73 | panic("unknown kernel type: " + strconv.Itoa(int(k.Type))) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /svm/trainer.go: -------------------------------------------------------------------------------- 1 | package svm 2 | 3 | import ( 4 | "log" 5 | "math/rand" 6 | "time" 7 | 8 | "github.com/unixpickle/num-analysis/linalg" 9 | "github.com/unixpickle/weakai/svm" 10 | "github.com/unixpickle/whichlang/tokens" 11 | ) 12 | 13 | const farAwayTimeout = time.Hour * 24 * 365 14 | 15 | func Train(data map[string][]tokens.Freqs) *Classifier { 16 | params, err := EnvTrainerParams() 17 | if err != nil { 18 | panic(err) 19 | } 20 | return TrainParams(data, params) 21 | } 22 | 23 | func TrainParams(data map[string][]tokens.Freqs, p *TrainerParams) *Classifier { 24 | crossFreqs, trainingFreqs := partitionSamples(data, p.CrossValidation) 25 | tokens, samples := vectorizeSamples(trainingFreqs) 26 | 27 | solver := svm.GradientDescentSolver{ 28 | Timeout: farAwayTimeout, 29 | Tradeoff: p.Tradeoff, 30 | } 31 | 32 | var bestClassifier *Classifier 33 | var bestValidationScore float64 34 | 35 | for _, kernel := range p.Kernels { 36 | if p.Verbose { 37 | log.Println("Trying kernel:", kernel) 38 | } 39 | solverKernel := cachedKernel(kernel) 40 | classifier := &Classifier{ 41 | Keywords: tokens, 42 | Kernel: kernel, 43 | Classifiers: map[string]BinaryClassifier{}, 44 | } 45 | 46 | usedSamples := map[int]linalg.Vector{} 47 | for lang := range samples {
48 | if p.Verbose { 49 | log.Println("Training classifier for language:", lang) 50 | } 51 | problem := svmProblem(samples, lang, solverKernel) 52 | solution := solver.Solve(problem) 53 | binClass := BinaryClassifier{ 54 | SupportVectors: make([]int, len(solution.SupportVectors)), 55 | Weights: make([]float64, len(solution.Coefficients)), 56 | Threshold: -solution.Threshold, 57 | } 58 | copy(binClass.Weights, solution.Coefficients) 59 | for i, v := range solution.SupportVectors { 60 | // v.UserInfo will be turned into a support 61 | // vector index by makeSampleVectorList(). 62 | binClass.SupportVectors[i] = v.UserInfo 63 | usedSamples[v.UserInfo] = linalg.Vector(v.V) 64 | } 65 | classifier.Classifiers[lang] = binClass 66 | } 67 | 68 | makeSampleVectorList(classifier, usedSamples) 69 | 70 | score := correctFraction(classifier, crossFreqs) 71 | if p.Verbose { 72 | trainingScore := correctFraction(classifier, trainingFreqs) 73 | log.Printf("Results: cross=%f training=%f support=%d/%d", score, 74 | trainingScore, len(classifier.SampleVectors), countSamples(samples)) 75 | } 76 | if score > bestValidationScore || bestClassifier == nil { 77 | bestClassifier, bestValidationScore = classifier, score 78 | } 79 | } 80 | 81 | return bestClassifier 82 | } 83 | 84 | func partitionSamples(data map[string][]tokens.Freqs, crossFrac float64) (cross, 85 | training map[string][]tokens.Freqs) { 86 | 87 | cross = map[string][]tokens.Freqs{} 88 | training = map[string][]tokens.Freqs{} 89 | 90 | for lang, samples := range data { 91 | p := rand.Perm(len(samples)) 92 | newSamples := make([]tokens.Freqs, len(samples)) 93 | for i, x := range p { 94 | newSamples[i] = samples[x] 95 | } 96 | crossCount := int(crossFrac * float64(len(samples))) 97 | cross[lang] = newSamples[:crossCount] 98 | training[lang] = newSamples[crossCount:] 99 | } 100 | 101 | return 102 | } 103 | 104 | func vectorizeSamples(data map[string][]tokens.Freqs) ([]string, map[string][]svm.Sample) { 105 | seenToks := map[string]bool{} 106 | for _, samples
:= range data { 107 | for _, sample := range samples { 108 | for tok := range sample { 109 | seenToks[tok] = true 110 | } 111 | } 112 | } 113 | toks := make([]string, 0, len(seenToks)) 114 | for tok := range seenToks { 115 | toks = append(toks, tok) 116 | } 117 | 118 | sampleMap := map[string][]svm.Sample{} 119 | sampleID := 1 120 | for lang, samples := range data { 121 | vecSamples := make([]svm.Sample, 0, len(samples)) 122 | for _, sample := range samples { 123 | vec := make([]float64, len(toks)) 124 | for i, tok := range toks { 125 | vec[i] = sample[tok] 126 | } 127 | svmSample := svm.Sample{ 128 | V: vec, 129 | UserInfo: sampleID, 130 | } 131 | sampleID++ 132 | vecSamples = append(vecSamples, svmSample) 133 | } 134 | sampleMap[lang] = vecSamples 135 | } 136 | 137 | return toks, sampleMap 138 | } 139 | 140 | func countSamples(s map[string][]svm.Sample) int { 141 | var count int 142 | for _, samples := range s { 143 | count += len(samples) 144 | } 145 | return count 146 | } 147 | 148 | func cachedKernel(k *Kernel) svm.Kernel { 149 | return svm.CachedKernel(func(s1, s2 svm.Sample) float64 { 150 | return k.Product(linalg.Vector(s1.V), linalg.Vector(s2.V)) 151 | }) 152 | } 153 | 154 | func svmProblem(data map[string][]svm.Sample, posLang string, k svm.Kernel) *svm.Problem { 155 | var positives, negatives []svm.Sample 156 | for lang, samples := range data { 157 | if lang == posLang { 158 | positives = append(positives, samples...) 159 | } else { 160 | negatives = append(negatives, samples...) 
161 | } 162 | } 163 | return &svm.Problem{ 164 | Positives: positives, 165 | Negatives: negatives, 166 | Kernel: k, 167 | } 168 | } 169 | 170 | func correctFraction(c *Classifier, data map[string][]tokens.Freqs) float64 { 171 | var correct, total int 172 | for lang, samples := range data { 173 | for _, sample := range samples { 174 | total++ 175 | if c.Classify(sample) == lang { 176 | correct++ 177 | } 178 | } 179 | } 180 | return float64(correct) / float64(total) 181 | } 182 | 183 | func makeSampleVectorList(c *Classifier, used map[int]linalg.Vector) { 184 | userInfoToVecIdx := map[int]int{} 185 | 186 | for userInfo, sample := range used { 187 | userInfoToVecIdx[userInfo] = len(c.SampleVectors) 188 | c.SampleVectors = append(c.SampleVectors, sample) 189 | } 190 | 191 | for _, binClass := range c.Classifiers { 192 | for i, userInfo := range binClass.SupportVectors { 193 | binClass.SupportVectors[i] = userInfoToVecIdx[userInfo] 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /svm/trainer_params.go: -------------------------------------------------------------------------------- 1 | package svm 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "strconv" 7 | ) 8 | 9 | const ( 10 | defaultTradeoff = 1e-5 11 | defaultCrossValidationFraction = 0.3 12 | ) 13 | 14 | var ( 15 | defaultRBFParams = [][]float64{{1e-5}, {1e-4}, {1e-3}, {1e-2}, {1e-1}, {1e0}, {1e1}, {1e2}} 16 | defaultPolyPowers = []float64{2} 17 | defaultPolySums = []float64{0, 1} 18 | ) 19 | 20 | // These environment variables specify 21 | // various parameters for the SVM trainer. 22 | const ( 23 | // Set this to "1" to get verbose logs. 24 | VerboseEnvVar = "SVM_VERBOSE" 25 | 26 | // You may set this to "linear", "rbf", or 27 | // "polynomial". 28 | KernelEnvVar = "SVM_KERNEL" 29 | 30 | // The numerical constant used in the 31 | // RBF kernel. 32 | RBFParamEnvVar = "SVM_RBF_PARAM" 33 | 34 | // The degree parameter for polynomial kernels. 
35 | PolyDegreeEnvVar = "SVM_POLY_DEGREE" 36 | 37 | // The summed term (added before applying the 38 | // exponent) for polynomial kernels. 39 | PolySumEnvVar = "SVM_POLY_SUM" 40 | 41 | // The tradeoff between margin size and hinge loss. 42 | // The higher the tradeoff value, the greater the 43 | // margin size, but at the expense of correct 44 | // classifications. 45 | TradeoffEnvVar = "SVM_TRADEOFF" 46 | 47 | // The fraction (from 0-1) of samples which are 48 | // used for cross validation. 49 | CrossValidationEnvVar = "SVM_CROSS_VALIDATION" 50 | ) 51 | 52 | // TrainerParams specifies parameters for the 53 | // SVM trainer. 54 | type TrainerParams struct { 55 | Verbose bool 56 | Kernels []*Kernel 57 | Tradeoff float64 58 | 59 | CrossValidation float64 60 | } 61 | 62 | // EnvTrainerParams generates TrainerParams 63 | // by reading environment variables. 64 | // If an environment variable is incorrectly 65 | // formatted, this returns an error. 66 | // When a variable is missing, a default value 67 | // or set of values will be used. 
68 | func EnvTrainerParams() (*TrainerParams, error) { 69 | var res TrainerParams 70 | var err error 71 | 72 | if res.Tradeoff, err = envTradeoff(); err != nil { 73 | return nil, err 74 | } 75 | if res.CrossValidation, err = envCrossValidation(); err != nil { 76 | return nil, err 77 | } 78 | res.Verbose = (os.Getenv(VerboseEnvVar) == "1") 79 | 80 | kernTypes, err := envKernelTypes() 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | for _, kernType := range kernTypes { 86 | params, err := envKernelParams(kernType) 87 | if err != nil { 88 | return nil, err 89 | } 90 | for _, param := range params { 91 | kernel := &Kernel{ 92 | Type: kernType, 93 | Params: param, 94 | } 95 | res.Kernels = append(res.Kernels, kernel) 96 | } 97 | } 98 | 99 | return &res, nil 100 | } 101 | 102 | func envTradeoff() (float64, error) { 103 | if val := os.Getenv(TradeoffEnvVar); val != "" { 104 | return strconv.ParseFloat(val, 64) 105 | } else { 106 | return defaultTradeoff, nil 107 | } 108 | } 109 | 110 | func envCrossValidation() (float64, error) { 111 | if val := os.Getenv(CrossValidationEnvVar); val != "" { 112 | return strconv.ParseFloat(val, 64) 113 | } else { 114 | return defaultCrossValidationFraction, nil 115 | } 116 | } 117 | 118 | func envKernelTypes() ([]KernelType, error) { 119 | if val := os.Getenv(KernelEnvVar); val != "" { 120 | res, ok := map[string]KernelType{ 121 | "linear": LinearKernel, 122 | "polynomial": PolynomialKernel, 123 | "rbf": RadialBasisKernel, 124 | }[val] 125 | if !ok { 126 | return nil, errors.New("unknown kernel: " + val) 127 | } else { 128 | return []KernelType{res}, nil 129 | } 130 | } else { 131 | return []KernelType{LinearKernel, PolynomialKernel, RadialBasisKernel}, nil 132 | } 133 | } 134 | 135 | func envKernelParams(t KernelType) ([][]float64, error) { 136 | switch t { 137 | case LinearKernel: 138 | return [][]float64{{}}, nil 139 | case RadialBasisKernel: 140 | if val := os.Getenv(RBFParamEnvVar); val != "" { 141 | res, err := 
strconv.ParseFloat(val, 64) 142 | if err != nil { 143 | return nil, errors.New("invalid RBF param: " + val) 144 | } 145 | return [][]float64{{res}}, nil 146 | } else { 147 | return defaultRBFParams, nil 148 | } 149 | case PolynomialKernel: 150 | powers := defaultPolyPowers 151 | sums := defaultPolySums 152 | if val := os.Getenv(PolySumEnvVar); val != "" { 153 | sum, err := strconv.ParseFloat(val, 64) 154 | if err != nil { 155 | return nil, errors.New("invalid poly sum: " + val) 156 | } 157 | sums = []float64{sum} 158 | } 159 | if val := os.Getenv(PolyDegreeEnvVar); val != "" { 160 | degree, err := strconv.ParseFloat(val, 64) 161 | if err != nil { 162 | return nil, errors.New("invalid poly degree: " + val) 163 | } 164 | powers = []float64{degree} 165 | } 166 | res := make([][]float64, 0, len(powers)*len(sums)) 167 | for _, power := range powers { 168 | for _, sum := range sums { 169 | res = append(res, []float64{sum, power}) 170 | } 171 | } 172 | return res, nil 173 | default: 174 | panic("unknown kernel: " + strconv.Itoa(int(t))) 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /tokens/counts.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | ) 7 | 8 | // Counts records the number of occurrences 9 | // of tokens in a given document. 10 | type Counts map[string]int 11 | 12 | // CountTokens counts the tokens in a document. 13 | // 14 | // Four different types of tokens are detected: 15 | // 16 | // - Heterogeneous tokens: any strings which 17 | // appear surrounded by whitespace. 18 | // - Homogeneous words: strings like "abcd" 19 | // or "123" which are one type of symbol. 20 | // - Line-initial words: both heterogeneous 21 | // and homogeneous words which begin a line. 22 | // These tokens start with "\n". 23 | // - Line-final words: both heterogeneous and 24 | // homogeneous words which end a line. 
25 | // These tokens end with "\n". 26 | // 27 | // No homogeneous tokens will be counted as 28 | // heterogeneous tokens, or vice versa. 29 | // All line-boundary words are counted twice, 30 | // once with the newline and once without it. 31 | // If one token makes up an entire line, it is 32 | // counted as both line-initial and line-final. 33 | func CountTokens(contents string) Counts { 34 | res := Counts{} 35 | for _, t := range heterogeneousTokens(contents) { 36 | res[t] += 1 37 | } 38 | for _, t := range homogeneousTokens(contents) { 39 | res[t] += 1 40 | } 41 | for _, t := range lineBoundaryTokens(contents) { 42 | res[t] += 1 43 | } 44 | return res 45 | } 46 | 47 | func lineBoundaryTokens(contents string) []string { 48 | var res []string 49 | 50 | lines := strings.Split(contents, "\n") 51 | for _, line := range lines { 52 | fields := strings.Fields(line) 53 | if len(fields) == 0 { 54 | continue 55 | } 56 | for i, field := range []string{fields[0], fields[len(fields)-1]} { 57 | homog := homogeneousTokens(field) 58 | hetero := heterogeneousTokens(field) 59 | for _, tokList := range [][]string{homog, hetero} { 60 | if len(tokList) > 0 { 61 | if i == 0 { 62 | res = append(res, "\n"+tokList[0]) 63 | } else { 64 | res = append(res, tokList[len(tokList)-1]+"\n") 65 | } 66 | } 67 | } 68 | } 69 | } 70 | 71 | return res 72 | } 73 | 74 | func heterogeneousTokens(contents string) []string { 75 | fields := strings.Fields(contents) 76 | res := make([]string, 0, len(fields)) 77 | for _, f := range fields { 78 | if !isHomogeneous(f) { 79 | res = append(res, f) 80 | } 81 | } 82 | return res 83 | } 84 | 85 | func homogeneousTokens(contents string) []string { 86 | tokens := []string{} 87 | res := "" 88 | lastClass := charClassSpace 89 | for _, ch := range contents { 90 | c := classForRune(ch) 91 | if c == lastClass { 92 | res += string(ch) 93 | continue 94 | } 95 | if lastClass != charClassSpace && len(res) > 0 { 96 | tokens = append(tokens, res) 97 | } 98 | res = string(ch) 
99 | lastClass = c 100 | } 101 | if lastClass != charClassSpace && len(res) > 0 { 102 | tokens = append(tokens, res) 103 | } 104 | return tokens 105 | } 106 | 107 | type charClass int 108 | 109 | const ( 110 | charClassLetter charClass = iota 111 | charClassNumber 112 | charClassSpace 113 | charClassSymbol 114 | ) 115 | 116 | func classForRune(r rune) charClass { 117 | if unicode.IsLetter(r) { 118 | return charClassLetter 119 | } else if unicode.IsDigit(r) { 120 | return charClassNumber 121 | } else if unicode.IsSpace(r) { 122 | return charClassSpace 123 | } 124 | return charClassSymbol 125 | } 126 | 127 | func isHomogeneous(s string) bool { 128 | if len(s) == 0 { 129 | return true 130 | } 131 | c := classForRune([]rune(s)[0]) 132 | for _, r := range s { 133 | if classForRune(r) != c { 134 | return false 135 | } 136 | } 137 | return true 138 | } 139 | -------------------------------------------------------------------------------- /tokens/counts_test.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import "testing" 4 | 5 | func TestCountTokens(t *testing.T) { 6 | document := "Hello this is a Hello123\ttest\nhello hi is1 is123\nhi!" 
7 | actual := CountTokens(document) 8 | expected := map[string]int{ 9 | "Hello123": 1, 10 | "is1": 1, 11 | "is123": 1, 12 | "Hello": 2, 13 | "this": 1, 14 | "is": 3, 15 | "a": 1, 16 | "123": 2, 17 | "test": 1, 18 | "hello": 1, 19 | "hi": 2, 20 | "1": 1, 21 | "!": 1, 22 | "hi!": 1, 23 | 24 | "\nHello": 1, 25 | "test\n": 1, 26 | "\nhello": 1, 27 | "is123\n": 1, 28 | "123\n": 1, 29 | "\nhi": 1, 30 | "\nhi!": 1, 31 | "hi!\n": 1, 32 | "!\n": 1, 33 | } 34 | 35 | for x, count := range expected { 36 | if actual[x] != count { 37 | t.Error("expected count", count, "for", x, "but got", actual[x]) 38 | } 39 | } 40 | 41 | for x := range actual { 42 | if expected[x] == 0 { 43 | t.Error("got unexpected token:", x) 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /tokens/freqs.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | // Freqs maps words to their frequencies. 4 | // The frequency for a token X is equal to 5 | // the number of occurrences of X, divided 6 | // by the total number of tokens in the 7 | // document. 8 | type Freqs map[string]float64 9 | 10 | // Freqs converts word counts into a 11 | // frequency map. 12 | func (c Counts) Freqs() Freqs { 13 | var totalCount int 14 | for _, count := range c { 15 | totalCount += count 16 | } 17 | res := Freqs{} 18 | for word, count := range c { 19 | res[word] = float64(count) / float64(totalCount) 20 | } 21 | return res 22 | } 23 | -------------------------------------------------------------------------------- /tokens/sample_counts.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "path/filepath" 7 | "sort" 8 | "strings" 9 | ) 10 | 11 | // SampleCounts maps programming languages to 12 | // arrays of language sample documents, where 13 | // each sample is represented by a Counts. 
14 | type SampleCounts map[string][]Counts 15 | 16 | // ReadSampleCounts computes token counts 17 | // for programming language samples in a 18 | // directory. 19 | // 20 | // The directory should contain sub-directories 21 | // for each programming language, and each of 22 | // these languages should contain one or more 23 | // source files. 24 | // 25 | // The returned map maps language names to lists 26 | // of Counts, where each Counts corresponds to 27 | // one source file. 28 | func ReadSampleCounts(sampleDir string) (SampleCounts, error) { 29 | languages, err := readDirectory(sampleDir, true) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | res := SampleCounts{} 35 | for _, language := range languages { 36 | langDir := filepath.Join(sampleDir, language) 37 | files, err := readDirectory(langDir, false) 38 | if err != nil { 39 | return nil, err 40 | } 41 | for _, file := range files { 42 | contents, err := ioutil.ReadFile(filepath.Join(langDir, file)) 43 | if err != nil { 44 | return nil, err 45 | } 46 | counts := CountTokens(string(contents)) 47 | res[language] = append(res[language], counts) 48 | } 49 | } 50 | 51 | return res, nil 52 | } 53 | 54 | // NumTokens returns the number of unique 55 | // tokens in all the documents. 56 | func (s SampleCounts) NumTokens() int { 57 | toks := map[string]bool{} 58 | for _, samples := range s { 59 | for _, sample := range samples { 60 | for word := range sample { 61 | toks[word] = true 62 | } 63 | } 64 | } 65 | return len(toks) 66 | } 67 | 68 | // Prune removes tokens which appear in n 69 | // documents or fewer. 70 | // 71 | // This creates a "" token in each document 72 | // corresponding to the number of pruned 73 | // tokens from that document. 
74 | func (s SampleCounts) Prune(n int) { 75 | docCount := map[string]int{} 76 | for _, samples := range s { 77 | for _, sample := range samples { 78 | for word := range sample { 79 | docCount[word]++ 80 | } 81 | } 82 | } 83 | 84 | remove := map[string]bool{} 85 | for word, count := range docCount { 86 | if count <= n { 87 | remove[word] = true 88 | } 89 | } 90 | 91 | for _, samples := range s { 92 | for i, sample := range samples { 93 | newSample := map[string]int{} 94 | removed := 0 95 | for word, count := range sample { 96 | if !remove[word] { 97 | newSample[word] = count 98 | } else { 99 | removed += count 100 | } 101 | } 102 | if removed > 0 { 103 | newSample[""] += removed 104 | } 105 | samples[i] = newSample 106 | } 107 | } 108 | } 109 | 110 | // SampleFreqs converts every Counts object 111 | // in s into a Freqs object. 112 | // The "" key in each Freqs object is deleted 113 | // if one exists. 114 | func (s SampleCounts) SampleFreqs() map[string][]Freqs { 115 | res := map[string][]Freqs{} 116 | for lang, samples := range s { 117 | for _, sample := range samples { 118 | f := sample.Freqs() 119 | delete(f, "") 120 | res[lang] = append(res[lang], f) 121 | } 122 | } 123 | return res 124 | } 125 | 126 | func readDirectory(dir string, isDir bool) ([]string, error) { 127 | f, err := os.Open(dir) 128 | if err != nil { 129 | return nil, err 130 | } 131 | defer f.Close() 132 | contents, err := f.Readdir(-1) 133 | if err != nil { 134 | return nil, err 135 | } 136 | res := make([]string, 0, len(contents)) 137 | for _, info := range contents { 138 | if info.IsDir() == isDir && !strings.HasPrefix(info.Name(), ".") { 139 | res = append(res, info.Name()) 140 | } 141 | } 142 | sort.Strings(res) 143 | return res, nil 144 | } 145 | -------------------------------------------------------------------------------- /tokens/sample_counts_test.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import ( 4 | "math" 5 | 
"testing" 6 | ) 7 | 8 | func TestSampleCountsPruneFreqs(t *testing.T) { 9 | docs := SampleCounts{ 10 | "A": []Counts{ 11 | {"Foo": 1, "Bar": 3, "Baz": 2, "Once1": 1}, 12 | {"Foo": 1, "Bar": 1, "Once2": 15}, 13 | }, 14 | "B": []Counts{ 15 | {"Baz": 15, "Once3": 17}, 16 | }, 17 | } 18 | docs.Prune(1) 19 | actual := docs.SampleFreqs() 20 | expected := map[string][]Freqs{ 21 | "A": []Freqs{ 22 | {"Foo": 1.0 / 7.0, "Bar": 3.0 / 7.0, "Baz": 2.0 / 7.0}, 23 | {"Foo": 1.0 / 17.0, "Bar": 1.0 / 17.0}, 24 | }, 25 | "B": []Freqs{ 26 | {"Baz": 15.0 / 32.0}, 27 | }, 28 | } 29 | for lang, freqs := range expected { 30 | actualFreqs := actual[lang] 31 | if len(actualFreqs) != len(freqs) { 32 | t.Error("unexpected document count for", lang) 33 | continue 34 | } 35 | for i, actualFreq := range actualFreqs { 36 | expFreq := freqs[i] 37 | if !freqsApproxEqual(actualFreq, expFreq) { 38 | t.Error("incorrect freq", actualFreq) 39 | } 40 | } 41 | } 42 | for lang := range actual { 43 | if expected[lang] == nil { 44 | t.Error("unexpected language", lang) 45 | } 46 | } 47 | } 48 | 49 | func freqsApproxEqual(f1, f2 Freqs) bool { 50 | if len(f1) != len(f2) { 51 | return false 52 | } 53 | for key, val := range f1 { 54 | if math.Abs(val-f2[key]) > 1e-5 { 55 | return false 56 | } 57 | } 58 | return true 59 | } 60 | --------------------------------------------------------------------------------
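
Taken together, the `tokens` package turns raw source text into the feature vectors the classifiers consume: `CountTokens` builds a `Counts` map and `Counts.Freqs` normalizes it by the total token count, as seen in `tokens/freqs.go` above. The normalization step can be sketched in isolation like this (a minimal, self-contained mirror of `freqs.go`; the sample tokens are made up for illustration):

```go
package main

import "fmt"

// Counts records token occurrences, mirroring tokens/counts.go.
type Counts map[string]int

// Freqs divides each token's count by the total number of tokens,
// mirroring the conversion in tokens/freqs.go.
func (c Counts) Freqs() map[string]float64 {
	var total int
	for _, n := range c {
		total += n
	}
	res := map[string]float64{}
	for tok, n := range c {
		res[tok] = float64(n) / float64(total)
	}
	return res
}

func main() {
	// A hypothetical document containing four tokens in total.
	counts := Counts{"func": 2, "package": 1, "{": 1}
	freqs := counts.Freqs()
	fmt.Printf("func: %.2f package: %.2f\n", freqs["func"], freqs["package"])
	// prints "func: 0.50 package: 0.25"
}
```

Working with frequencies rather than raw counts keeps document length from dominating the feature vectors handed to the SVM and the other classifiers.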