├── loader
│   ├── test_url2.data
│   ├── test_url1.data
│   ├── test_spider.conf
│   ├── seed_load.go
│   ├── seed_load_test.go
│   ├── config_load.go
│   └── config_load_test.go
├── data
│   └── url.data
├── go.mod
├── conf
│   └── spider.conf
├── crawler
│   ├── crawl_test.go
│   └── crawl.go
├── saver
│   ├── save_data.go
│   └── save_data_test.go
├── README.md
├── main
│   └── mini_spider.go
├── scheduler
│   ├── task.go
│   ├── task_test.go
│   ├── scheduler_test.go
│   └── scheduler.go
└── parser
    ├── parse_test.go
    └── parse.go

/loader/test_url2.data:
--------------------------------------------------------------------------------
[
]
--------------------------------------------------------------------------------
/data/url.data:
--------------------------------------------------------------------------------
[
    "http://www.baidu.com",
    "http://www.sina.com.cn"
]
--------------------------------------------------------------------------------
/loader/test_url1.data:
--------------------------------------------------------------------------------
[
    "http://www.baidu.com",
    "http://www.sina.com.cn",
]
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/coopersong/mini-spider

go 1.14

require (
    github.com/baidu/go-lib v0.0.0-20210316014414-55daa983069e
    golang.org/x/net v0.0.0-20201224014010-6772e930b67b
    gopkg.in/gcfg.v1 v1.2.3
    gopkg.in/warnings.v0 v0.1.2 // indirect
)
--------------------------------------------------------------------------------
/loader/test_spider.conf:
--------------------------------------------------------------------------------
[spider]
# path of the seed file
urlListFile = ../data/url.data
# directory for storing crawl results
outputDirectory = ../output
# maximum crawl depth (invalid on purpose: must be >= 1)
maxDepth = 0
# crawl interval
crawlInterval = 1
# crawl timeout
crawlTimeout = 1
# URL pattern of target pages to store
targetUrl = .*.(htm|html)$
# number of crawl goroutines
threadCount = 8
--------------------------------------------------------------------------------
/conf/spider.conf:
--------------------------------------------------------------------------------
[spider]
# path of the seed file
urlListFile = ../data/url.data
# directory for storing crawl results
outputDirectory = ../output
# maximum crawl depth
maxDepth = 2
# crawl interval, in seconds
crawlInterval = 1
# crawl timeout, in seconds
crawlTimeout = 1
# URL pattern of target pages to store
targetUrl = .*.(htm|html)$
# number of crawl goroutines
threadCount = 8
--------------------------------------------------------------------------------
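A note on the `targetUrl` pattern above: the dot before `(htm|html)` is unescaped, so it matches any single character; `.*\.(htm|html)$` would require a literal dot before the extension. A quick standalone check (this snippet is illustrative and not part of the repository):

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	loose := regexp.MustCompile(`.*.(htm|html)$`)   // pattern as configured
	strict := regexp.MustCompile(`.*\.(htm|html)$`) // dot escaped
	fmt.Println(loose.MatchString("http://a.com/index.html")) // true
	fmt.Println(loose.MatchString("http://a.com/xhtml"))      // true: "." matches "x"
	fmt.Println(strict.MatchString("http://a.com/xhtml"))     // false
}
```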
/loader/seed_load.go:
--------------------------------------------------------------------------------
package loader

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
)

// SeedLoad loads seeds from path.
func SeedLoad(path string) ([]string, error) {
	var seeds []string
	data, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}
	err = json.Unmarshal(data, &seeds)
	if err != nil {
		return nil, fmt.Errorf("json.Unmarshal(): %s", err.Error())
	}
	if len(seeds) == 0 {
		return nil, fmt.Errorf("no seed in %s", path)
	}
	return seeds, nil
}
--------------------------------------------------------------------------------
/crawler/crawl_test.go:
--------------------------------------------------------------------------------
package crawler

import (
	"testing"
)

func TestCrawl(t *testing.T) {
	// normal case
	_, _, err := Crawl("https://www.baidu.com", 3)
	if err != nil {
		t.Errorf("failed to crawl www.baidu.com")
		return
	}

	// invalid url
	_, _, err = Crawl("xxx[]009", 3)
	if err == nil {
		t.Errorf("crawling an invalid url should cause an error but did not")
		return
	}

	// unreachable url
	_, _, err = Crawl("https://www.guangze.com", 3)
	if err == nil {
		t.Errorf("no website named https://www.guangze.com but no error was caused")
		return
	}
}
--------------------------------------------------------------------------------
/saver/save_data.go:
--------------------------------------------------------------------------------
package saver

import (
	"fmt"
	"net/url"
	"os"
	"path/filepath"
)

// SaveData creates a file named after urlStr in outputDirectory
// and saves data to it, truncating any previous content.
func SaveData(data []byte, urlStr string, outputDirectory string) error {
	fileName := filepath.Join(outputDirectory, url.QueryEscape(urlStr))

	f, err := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
	if err != nil {
		return fmt.Errorf("%s: os.OpenFile(): %s", fileName, err.Error())
	}
	defer f.Close()

	_, err = f.Write(data)
	if err != nil {
		return fmt.Errorf("f.Write(): %s", err.Error())
	}

	return nil
}
--------------------------------------------------------------------------------
/saver/save_data_test.go:
--------------------------------------------------------------------------------
package saver

import (
	"os/exec"
	"testing"
)

func TestSaveData(t *testing.T) {
	mkCmd := exec.Command("/bin/bash", "-c", "mkdir ../test_output")
	rmCmd := exec.Command("/bin/bash", "-c", "rm -rf ../test_output")
	err := mkCmd.Start()
	if err != nil {
		t.Errorf("mkCmd.Start(): %s", err.Error())
		return
	}
	err = mkCmd.Wait()
	if err != nil {
		t.Errorf("mkCmd.Wait(): %s", err.Error())
		return
	}
	defer rmCmd.Start()

	data := []byte("Hello World!")
	err = SaveData(data, "www.test.com", "../test_output")
	if err != nil {
		t.Errorf("SaveData(): %s", err.Error())
		return
	}
}
--------------------------------------------------------------------------------
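For reference, SaveData above derives file names by query-escaping the full URL, which keeps them filesystem-safe (no "/" or ":"). A standalone illustration (the URL is an arbitrary example):

```go
package main

import (
	"fmt"
	"net/url"
	"path/filepath"
)

func main() {
	// mirrors how SaveData builds its target path
	name := filepath.Join("../output", url.QueryEscape("http://www.baidu.com/index.html"))
	fmt.Println(name) // ../output/http%3A%2F%2Fwww.baidu.com%2Findex.html
}
```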
/loader/seed_load_test.go:
--------------------------------------------------------------------------------
package loader

import (
	"testing"
)

func TestSeedLoad(t *testing.T) {
	// file path does not exist
	_, err := SeedLoad("./not_exist.data")
	if err == nil {
		t.Errorf("./not_exist.data does not exist but there is no error")
		return
	}
	// invalid json
	_, err = SeedLoad("./test_url1.data")
	if err == nil {
		t.Errorf("./test_url1.data's json format is invalid but there is no error")
		return
	}
	// seed file contains no seeds
	_, err = SeedLoad("./test_url2.data")
	if err == nil {
		t.Errorf("there is no seed in ./test_url2.data but there is no error")
		return
	}
	// normal case
	_, err = SeedLoad("../data/url.data")
	if err != nil {
		t.Errorf("../data/url.data is valid but there is an error")
		return
	}
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Project Name

mini-spider

## Background

During research we often need to crawl selected websites in a targeted way. This is a mini targeted crawler written in Go: it crawls from a set of seed links and saves to disk every page whose URL matches a given regular expression.

## Features

Crawls the seed links and saves to disk the pages whose URLs match the target regular expression.

## Quick Start
### Build

Run the following command in the project root:

```bash
sh build.sh
```

Three directories are generated by default: bin holds the executable, log holds the logs, and output holds the downloaded pages.

### Flags

* -h show help
* -c set the configuration directory, default ../conf
* -l set the log directory, default ../log
* -v show the version

## Design
### Initialization

* Parse the command-line flags, initialize logging, and load the configuration file and the seed file
* Create and initialize a Scheduler, then call its Start method to begin scheduling

### Scheduling loop

* The for loop exits when the task queue and the channel are both empty
* On each iteration the Scheduler removes a task from the task queue and runs it

### Task execution

* If the task's depth is greater than or equal to MaxDepth, return
* If the task's URL has already been crawled, return; otherwise run the task asynchronously in a goroutine
* Resolve the task's hostname (site) and enforce the crawl-interval requirement
* Fetch the page at the URL; on failure, log the error and return
* If the Content-Type is not text, log it and return
* Convert the fetched page to UTF-8
* If the URL matches the target regular expression, save the page to disk
* Parse the fetched page and append its sub-URLs to the task queue

### Concurrency control

* The maximum concurrency is bounded by a buffered channel (see the sketch after this README)

### Crawl-interval control

* Implemented with sync.Map and time.Timer
* The sync.Map keys are hostnames and the values are timers
* Before each crawl, the hostname is parsed from the task URL and used to look up the site's timer; after waiting out the timer's remaining time, the timer is reset and the crawl proceeds

### Graceful exit

* A "useless task" guarantees that taskChan drains before taskQue does; see the code and comments for details

## Testing

Every package except main has unit tests. Note that in the scheduler package several test functions create and delete the same directory, so the tests must run with parallelism limited to 1:

```bash
go test -parallel 1
```
--------------------------------------------------------------------------------
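The buffered-channel concurrency control described in the README reduces to a counting semaphore. A minimal standalone sketch of the idea, with illustrative names (`sem`, `threadCount`) that are not taken from the repository:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	const threadCount = 8
	sem := make(chan struct{}, threadCount) // at most threadCount goroutines in flight

	var wg sync.WaitGroup
	for i := 0; i < 20; i++ {
		sem <- struct{}{} // blocks while threadCount tasks are running
		wg.Add(1)
		go func(n int) {
			defer func() { <-sem; wg.Done() }()
			fmt.Println("task", n)
		}(i)
	}
	wg.Wait()
}
```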
/loader/config_load.go:
--------------------------------------------------------------------------------
package loader

import (
	"fmt"
	"regexp"
)

import (
	"gopkg.in/gcfg.v1"
)

type Spider struct {
	// path of the seed file
	UrlListFile string
	// directory for storing crawl results
	OutputDirectory string
	// maximum crawl depth
	MaxDepth int
	// crawl interval, in seconds
	CrawlInterval int
	// crawl timeout, in seconds
	CrawlTimeout int
	// URL pattern of target pages to store
	TargetUrl string
	// number of crawl goroutines
	ThreadCount int
}

type Config struct {
	Spider
}

// ConfigLoad loads config from confPath.
func ConfigLoad(confPath string) (Config, error) {
	var cfg Config

	if err := gcfg.ReadFileInto(&cfg, confPath); err != nil {
		return cfg, err
	}

	if err := cfg.Check(); err != nil {
		return cfg, err
	}

	return cfg, nil
}

// Check validates the config.
func (c *Config) Check() error {
	if c.UrlListFile == "" {
		return fmt.Errorf("UrlListFile is empty")
	}

	if c.OutputDirectory == "" {
		return fmt.Errorf("OutputDirectory is empty")
	}

	if c.MaxDepth < 1 {
		return fmt.Errorf("MaxDepth is less than 1")
	}

	if c.CrawlInterval < 0 {
		return fmt.Errorf("CrawlInterval is less than 0")
	}

	if c.CrawlTimeout < 1 {
		return fmt.Errorf("CrawlTimeout is less than 1")
	}

	_, err := regexp.Compile(c.TargetUrl)
	if err != nil {
		return fmt.Errorf("%s: regexp.Compile(): %s", c.TargetUrl, err.Error())
	}

	if c.ThreadCount < 1 {
		return fmt.Errorf("ThreadCount is less than 1")
	}

	return nil
}
--------------------------------------------------------------------------------
/main/mini_spider.go:
--------------------------------------------------------------------------------
package main

import (
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"time"
)

import (
	"github.com/baidu/go-lib/log"
	"github.com/baidu/go-lib/log/log4go"
)

import (
	"github.com/coopersong/mini-spider/loader"
	"github.com/coopersong/mini-spider/scheduler"
)

const (
	Version            = "v1.0"
	SpiderConfFileName = "spider.conf"
)

var (
	confPath = flag.String("c", "../conf", "root path of configuration")
	logPath  = flag.String("l", "../log", "dir path of log")
	help     = flag.Bool("h", false, "to show help")
	version  = flag.Bool("v", false, "to show version")
)

// Exit closes the logger, waits briefly for buffered logs to flush,
// then exits with code.
func Exit(code int) {
	log.Logger.Close()
	time.Sleep(100 * time.Millisecond)
	os.Exit(code)
}

func initLog(logSwitch string, logPath *string, stdOut bool) error {
	log4go.SetLogBufferLength(10000)
	log4go.SetLogWithBlocking(false)

	err := log.Init("mini_spider", logSwitch, *logPath, stdOut, "midnight", 5)
	if err != nil {
		return fmt.Errorf("err in log.Init(): %s", err.Error())
	}

	return nil
}

func main() {
	var err error

	flag.Parse()

	if *help {
		flag.PrintDefaults()
		return
	}

	if *version {
		fmt.Println(Version)
		return
	}

	err = initLog("INFO", logPath, true)
	if err != nil {
		fmt.Printf("initLog(): %s\n", err.Error())
		Exit(-1)
	}

	config, err := loader.ConfigLoad(filepath.Join(*confPath, SpiderConfFileName))
	if err != nil {
		log.Logger.Error("loader.ConfigLoad(): %s", err.Error())
		Exit(-1)
	}

	seeds, err := loader.SeedLoad(config.UrlListFile)
	if err != nil {
		log.Logger.Error("loader.SeedLoad(): %s", err.Error())
		Exit(-1)
	}

	miniSpider := scheduler.NewScheduler()
	miniSpider.Init(config, seeds)
	miniSpider.Start()

	Exit(0)
}
--------------------------------------------------------------------------------
/scheduler/task.go:
--------------------------------------------------------------------------------
package scheduler

import (
	"fmt"
	"regexp"
	"strings"
)

import (
	"github.com/coopersong/mini-spider/crawler"
	"github.com/coopersong/mini-spider/parser"
	"github.com/coopersong/mini-spider/saver"
)

type TaskCommonConfig struct {
	// crawl timeout
	CrawlTimeout int
	// page download directory
	OutputDirectory string
	// regular expression for target pages to store
	TargetUrlPattern *regexp.Regexp
}

type Task struct {
	// url to crawl
	Url string
	// crawl depth
	Depth int
	// common configuration
	CommonCfg *TaskCommonConfig
}

// NewUselessTask creates a new useless task. Its depth equals maxDepth,
// so RunTask returns immediately without doing any work.
func NewUselessTask(maxDepth int) *Task {
	return &Task{
		Url:   "",
		Depth: maxDepth,
	}
}

// Run runs a single task.
// A successful call returns the sub url list and err == nil.
func (task *Task) Run() ([]string, error) {
	data, contentType, err := crawler.Crawl(task.Url, task.CommonCfg.CrawlTimeout)
	if err != nil {
		return nil, fmt.Errorf("%s: crawler.Crawl(): %s", task.Url, err.Error())
	}

	if !strings.Contains(contentType, "text") {
		return nil, fmt.Errorf("%s: Content-Type: %s", task.Url, contentType)
	}

	data, err = parser.Convert2Utf8(data, contentType)
	if err != nil {
		return nil, fmt.Errorf("%s: parser.Convert2Utf8(): %s", task.Url, err.Error())
	}

	if task.CommonCfg.TargetUrlPattern.MatchString(task.Url) {
		err = task.SaveData(data)
		if err != nil {
			return nil, fmt.Errorf("%s: task.SaveData(): %s", task.Url, err.Error())
		}
	}

	urlList, err := parser.GetUrlList(data, task.Url)
	if err != nil {
		return nil, fmt.Errorf("%s: parser.GetUrlList(): %s", task.Url, err.Error())
	}

	return urlList, nil
}

// SaveData saves data to the output directory.
func (task *Task) SaveData(data []byte) error {
	err := saver.SaveData(data, task.Url, task.CommonCfg.OutputDirectory)
	if err != nil {
		return fmt.Errorf("saver.SaveData(): %s", err.Error())
	}

	return nil
}
--------------------------------------------------------------------------------
/scheduler/task_test.go:
--------------------------------------------------------------------------------
package scheduler

import (
	"os/exec"
	"regexp"
	"testing"
)

func TestNewUselessTask(t *testing.T) {
	uselessTask := NewUselessTask(3)
	if uselessTask == nil {
		t.Errorf("uselessTask should not be nil but is nil")
		return
	}
}

func TestTask_Run(t *testing.T) {
	mkCmd := exec.Command("/bin/bash", "-c", "mkdir ../test_output")
	rmCmd := exec.Command("/bin/bash", "-c", "rm -rf ../test_output")

	err := mkCmd.Start()
	if err != nil {
		t.Errorf("mkCmd.Start(): %s", err.Error())
		return
	}
	err = mkCmd.Wait()
	if err != nil {
		t.Errorf("mkCmd.Wait(): %s", err.Error())
		return
	}
	defer rmCmd.Start()

	targetUrlPattern, _ := regexp.Compile(".*.(htm|html)$")
	task := &Task{
		Url:   "http://www.test.com",
		Depth: 1,
		CommonCfg: &TaskCommonConfig{
			CrawlTimeout:     2,
			OutputDirectory:  "../test_output",
			TargetUrlPattern: targetUrlPattern,
		},
	}

	urlList, err := task.Run()
	if err == nil || len(urlList) > 0 {
		t.Errorf("task.Run() should return an error but did not")
		return
	}

	task.Url = "http://www.baidu.com"
	_, err = task.Run()
	if err != nil {
		t.Errorf("task.Run(): %s", err.Error())
		return
	}
}

func TestTask_SaveData(t *testing.T) {
	mkCmd := exec.Command("/bin/bash", "-c", "mkdir ../test_output")
	rmCmd := exec.Command("/bin/bash", "-c", "rm -rf ../test_output")

	err := mkCmd.Start()
	if err != nil {
		t.Errorf("mkCmd.Start(): %s", err.Error())
		return
	}
	err = mkCmd.Wait()
	if err != nil {
		t.Errorf("mkCmd.Wait(): %s", err.Error())
		return
	}
	defer rmCmd.Start()

	task := &Task{
		Url:   "http://www.test.com",
		Depth: 1,
		CommonCfg: &TaskCommonConfig{
			OutputDirectory: "../test_output",
		},
	}
	data := []byte("Hello World!")
	err = task.SaveData(data)
	if err != nil {
		t.Errorf("task.SaveData(): %s", err.Error())
		return
	}

	task.CommonCfg.OutputDirectory = "../xxx"
	err = task.SaveData(data)
	if err == nil {
		t.Errorf("task.SaveData() should return an error but did not")
		return
	}
}
--------------------------------------------------------------------------------
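A note on the Content-Type gate in Task.Run above: the check is a plain substring match, so any textual type passes, including ones carrying a charset parameter. A tiny standalone check (the values are illustrative):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	for _, ct := range []string{"text/html; charset=gb2312", "text/plain", "image/png"} {
		fmt.Printf("%-26s parsed: %v\n", ct, strings.Contains(ct, "text"))
	}
}
```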
/crawler/crawl.go:
--------------------------------------------------------------------------------
package crawler

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"time"
)

const (
	HeaderKeyUserAgent = "User-Agent"
)

const (
	Mozilla = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/74.0.3729.108 Safari/537.36"
)

// Crawl fetches url within timeout.
// A successful call returns the data, the content type of url and err == nil.
func Crawl(url string, timeout int) ([]byte, string, error) {
	var body []byte
	var contentType string
	var err error

	client := &http.Client{
		Timeout: time.Duration(timeout) * time.Second,
	}

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, "", fmt.Errorf("%s: http.NewRequest(): %s", url, err.Error())
	}
	req.Header.Add(HeaderKeyUserAgent, Mozilla)

	timer := time.NewTimer(time.Duration(timeout) * time.Second)

	// errChan receives any error returned by the goroutine below.
	// The goroutine always sends on errChan when it finishes, whether or not
	// an error occurred, so the channel also signals that the goroutine is done.
	// A buffered channel of capacity 1 guarantees that the send does not block
	// after a timeout has occurred, which prevents a goroutine leak.
	errChan := make(chan error, 1)

	go func() {
		var resp *http.Response
		resp, err = client.Do(req)
		if err != nil {
			err = fmt.Errorf("%s: client.Do(): %s", url, err.Error())
			errChan <- err
			return
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			err = fmt.Errorf("%s: status code[%d] not 200", url, resp.StatusCode)
			errChan <- err
			return
		}

		contentType = resp.Header.Get("Content-Type")

		body, err = ioutil.ReadAll(resp.Body)
		if err != nil {
			err = fmt.Errorf("ioutil.ReadAll(): %s", err.Error())
			errChan <- err
			return
		}

		errChan <- nil
	}()

	// wait until the crawl is done or the timeout expires
	select {
	case err = <-errChan:
		if err != nil {
			return nil, "", err
		}
	case <-timer.C:
		return nil, "", fmt.Errorf("crawl timeout")
	}

	return body, contentType, err
}
--------------------------------------------------------------------------------
/parser/parse_test.go:
--------------------------------------------------------------------------------
package parser

import (
	"testing"
)

import (
	"github.com/coopersong/mini-spider/crawler"
)

func TestGetUrlList(t *testing.T) {
	testUrl := "https://www.baidu.com"

	data, _, err := crawler.Crawl(testUrl, 3)
	if err != nil {
		t.Errorf("%s: crawler.Crawl(): %s", testUrl, err.Error())
		return
	}

	urlList, err := GetUrlList(data, testUrl)
	if err != nil {
		t.Errorf("GetUrlList(): %s", err.Error())
		return
	}

	if len(urlList) == 0 {
		t.Errorf("no sublink in %s", testUrl)
		return
	}
}

func TestConvert2Utf8(t *testing.T) {
	testUrl := "http://vip.stock.finance.sina.com.cn/q/go.php/vDYData/kind/znzd/index.phtml"

	data, contentType, err := crawler.Crawl(testUrl, 3)
	if err != nil {
		t.Errorf("%s: crawler.Crawl(): %s", testUrl, err.Error())
		return
	}

	data, err = Convert2Utf8(data, contentType)
	if err != nil {
		t.Errorf("Convert2Utf8(): %s", err.Error())
		return
	}
}

func TestParseHostName(t *testing.T) {
	rawUrl := "http://xxx.baidu.com/v1/configs"
	hostName, err := ParseHostName(rawUrl)
	if err != nil {
		t.Errorf("%s: ParseHostName(): %s", rawUrl, err.Error())
		return
	}
	if hostName != "xxx.baidu.com" {
		t.Errorf("hostName: %s != xxx.baidu.com", hostName)
		return
	}

	rawUrl = "http://xxx.baidu.com:8080/v1/configs"
	hostName, err = ParseHostName(rawUrl)
	if err != nil {
		t.Errorf("%s: ParseHostName(): %s", rawUrl, err.Error())
		return
	}
	if hostName != "xxx.baidu.com" {
		t.Errorf("hostName: %s != xxx.baidu.com", hostName)
		return
	}

	rawUrl = "http:++xxx.baidu.com:8080/v1/configs/page"
	hostName, err = ParseHostName(rawUrl)
	if err == nil {
		t.Errorf("%s: there should be an error but there is not, hostName: %s", rawUrl, hostName)
		return
	}

	rawUrl = "http://xxx.baidu.com:8080/v1\n/configs/page"
	hostName, err = ParseHostName(rawUrl)
	if err == nil {
		t.Errorf("%s: there should be an error but there is not, hostName: %s", rawUrl, hostName)
		return
	}
}
--------------------------------------------------------------------------------
/parser/parse.go:
--------------------------------------------------------------------------------
package parser

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/url"
	"strings"
)

import (
	"golang.org/x/net/html"
	"golang.org/x/net/html/charset"
)

// Convert2Utf8 converts raw to utf8.
func Convert2Utf8(raw []byte, contentType string) ([]byte, error) {
	reader := bytes.NewReader(raw)

	utf8Reader, err := charset.NewReader(reader, contentType)
	if err != nil {
		return nil, err
	}

	return ioutil.ReadAll(utf8Reader)
}

// getUrlList collects all sub urls of node.
func getUrlList(node *html.Node, refUrl *url.URL) []string {
	var urlList []string

	if node.Type == html.ElementNode && node.Data == "a" {
		for _, a := range node.Attr {
			if a.Key == "href" {
				if a.Val != "javascript:;" && a.Val != "javascript:void(0)" {
					url, err := refUrl.Parse(a.Val)
					if err == nil {
						urlList = append(urlList, url.String())
					}
				}
				break
			}
		}
	}

	for child := node.FirstChild; child != nil; child = child.NextSibling {
		childUrlList := getUrlList(child, refUrl)
		urlList = append(urlList, childUrlList...)
	}

	return urlList
}

// GetUrlList gets the sub url list of preUrl from data.
func GetUrlList(data []byte, preUrl string) ([]string, error) {
	// parse html
	node, err := html.Parse(bytes.NewReader(data))
	if err != nil {
		return nil, fmt.Errorf("html.Parse(): %s", err.Error())
	}

	// parse url
	refUrl, err := url.ParseRequestURI(preUrl)
	if err != nil {
		return nil, fmt.Errorf("%s: url.ParseRequestURI(): %s", preUrl, err.Error())
	}

	urlList := getUrlList(node, refUrl)

	return urlList, nil
}

// ParseHostName parses the hostname from a raw url.
func ParseHostName(rawUrl string) (string, error) {
	u, err := url.Parse(rawUrl)
	if err != nil {
		return "", err
	}

	if u.Host == "" {
		return "", fmt.Errorf("empty host")
	}

	// the host may carry a port, e.g. xxx.baidu.com:8080
	hostName := strings.Split(u.Host, ":")
	if len(hostName) == 0 {
		return "", fmt.Errorf("invalid hostname")
	}

	return hostName[0], nil
}
--------------------------------------------------------------------------------
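ParseHostName above splits u.Host on ":" by hand. Since Go 1.8 the standard library also offers (*url.URL).Hostname(), which strips the port (and the brackets around IPv6 literals); a comparison sketch:

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	u, err := url.Parse("http://xxx.baidu.com:8080/v1/configs")
	if err != nil {
		panic(err)
	}
	fmt.Println(u.Hostname()) // xxx.baidu.com, port already stripped
}
```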
/scheduler/scheduler_test.go:
--------------------------------------------------------------------------------
package scheduler

import (
	"os/exec"
	"testing"
)

import (
	"github.com/coopersong/mini-spider/loader"
)

func TestNewScheduler(t *testing.T) {
	scheduler := NewScheduler()
	if scheduler == nil {
		t.Errorf("scheduler is nil")
		return
	}
}

func TestScheduler_Start(t *testing.T) {
	mkCmd := exec.Command("/bin/bash", "-c", "mkdir ../test_output")
	rmCmd := exec.Command("/bin/bash", "-c", "rm -rf ../test_output")

	err := mkCmd.Start()
	if err != nil {
		t.Errorf("mkCmd.Start(): %s", err.Error())
		return
	}
	err = mkCmd.Wait()
	if err != nil {
		t.Errorf("mkCmd.Wait(): %s", err.Error())
		return
	}
	defer rmCmd.Start()

	scheduler := NewScheduler()
	cfg := loader.Config{
		loader.Spider{
			UrlListFile:     "../data/url.data",
			OutputDirectory: "../test_output",
			MaxDepth:        1,
			CrawlInterval:   1,
			CrawlTimeout:    1,
			TargetUrl:       ".*.(htm|html)$",
			ThreadCount:     8,
		},
	}
	seeds := []string{"http://www.baidu.com", "http://www.sina.com"}

	scheduler.Init(cfg, seeds)
	scheduler.Start()
}

func TestScheduler_RunTask(t *testing.T) {
	mkCmd := exec.Command("/bin/bash", "-c", "mkdir ../test_output")
	rmCmd := exec.Command("/bin/bash", "-c", "rm -rf ../test_output")
	err := mkCmd.Start()
	if err != nil {
		t.Errorf("mkCmd.Start(): %s", err.Error())
		return
	}
	err = mkCmd.Wait()
	if err != nil {
		t.Errorf("mkCmd.Wait(): %s", err.Error())
		return
	}
	defer rmCmd.Start()

	scheduler := NewScheduler()
	cfg := loader.Config{
		loader.Spider{
			UrlListFile:     "../data/url.data",
			OutputDirectory: "../test_output",
			MaxDepth:        2,
			CrawlInterval:   1,
			CrawlTimeout:    1,
			TargetUrl:       ".*.(htm|html)$",
			ThreadCount:     8,
		},
	}
	seeds := []string{"http://www.baidu.com", "http://www.sina.com"}
	scheduler.Init(cfg, seeds)
	task := &Task{
		Url:       "http://www.test.com",
		Depth:     0,
		CommonCfg: scheduler.TaskCommonCfg,
	}
	scheduler.RunTask(task)
	task.Url = "http://www.baidu.com"
	scheduler.RunTask(task)
}
--------------------------------------------------------------------------------
/loader/config_load_test.go:
--------------------------------------------------------------------------------
package loader

import (
	"fmt"
	"testing"
)

func TestConfigLoad(t *testing.T) {
	// file does not exist
	_, err := ConfigLoad("./not_exist.conf")
	if err == nil {
		t.Errorf("./not_exist.conf does not exist but there is no error")
		return
	}

	// file exists but the data is invalid
	_, err = ConfigLoad("./test_spider.conf")
	if err == nil {
		t.Errorf("./test_spider.conf is invalid but there is no error")
		return
	}

	// normal case
	_, err = ConfigLoad("../conf/spider.conf")
	if err != nil {
		t.Errorf("../conf/spider.conf is valid but there is an error")
		return
	}
}

func TestConfig_Check(t *testing.T) {
	var testCases = []struct {
		input *Config
		want  error
	}{
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "../output",
					MaxDepth:        1,
					CrawlInterval:   1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(htm|html)$",
					ThreadCount:     8,
				},
			},
			nil,
		},
		{
			&Config{
				Spider{
					UrlListFile:     "",
					OutputDirectory: "../output",
					MaxDepth:        1,
					CrawlInterval:   1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(htm|html)$",
					ThreadCount:     8,
				},
			},
			fmt.Errorf("there must be an error"),
		},
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "",
					MaxDepth:        1,
					CrawlInterval:   1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(htm|html)$",
					ThreadCount:     8,
				},
			},
			fmt.Errorf("there must be an error"),
		},
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "../output",
					MaxDepth:        0,
					CrawlInterval:   1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(htm|html)$",
					ThreadCount:     8,
				},
			},
			fmt.Errorf("there must be an error"),
		},
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "../output",
					MaxDepth:        1,
					CrawlInterval:   -1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(htm|html)$",
					ThreadCount:     8,
				},
			},
			fmt.Errorf("there must be an error"),
		},
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "../output",
					MaxDepth:        1,
					CrawlInterval:   1,
					CrawlTimeout:    0,
					TargetUrl:       ".*.(htm|html)$",
					ThreadCount:     8,
				},
			},
			fmt.Errorf("there must be an error"),
		},
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "../output",
					MaxDepth:        1,
					CrawlInterval:   1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(((htm|html)$",
					ThreadCount:     8,
				},
			},
			fmt.Errorf("there must be an error"),
		},
		{
			&Config{
				Spider{
					UrlListFile:     "../data/url.data",
					OutputDirectory: "../output",
					MaxDepth:        1,
					CrawlInterval:   1,
					CrawlTimeout:    1,
					TargetUrl:       ".*.(((htm|html)$",
					ThreadCount:     0,
				},
			},
			fmt.Errorf("there must be an error"),
		},
	}

	for index, testCase := range testCases {
		err := testCase.input.Check()
		if testCase.want != nil {
			if err == nil {
				t.Errorf("testCases[%d] should cause an error but did not", index)
				return
			}
		} else {
			if err != nil {
				t.Errorf("testCases[%d] should not cause an error but there is an error", index)
				return
			}
		}
	}
}
--------------------------------------------------------------------------------
/scheduler/scheduler.go:
--------------------------------------------------------------------------------
package scheduler

import (
	"regexp"
	"sync"
	"time"
)

import (
	"github.com/baidu/go-lib/log"
	"github.com/baidu/go-lib/queue"
)

import (
	"github.com/coopersong/mini-spider/loader"
	"github.com/coopersong/mini-spider/parser"
)

type Scheduler struct {
	// task queue
	TaskQue queue.Queue
	// url deduplication table
	UrlTable sync.Map
	// task channel
	TaskChan chan struct{}
	// maximum crawl depth
	MaxDepth int
	// crawl interval, in seconds
	CrawlInterval int
	// number of goroutines used for crawl tasks
	ThreadCount int
	// common task configuration
	TaskCommonCfg *TaskCommonConfig
	// per-site crawl-interval timer table
	TimerTable sync.Map
}

// NewScheduler creates a new scheduler.
func NewScheduler() *Scheduler {
	return new(Scheduler)
}

// Init initializes the scheduler with config
// and fills its task queue with the seeds.
func (s *Scheduler) Init(config loader.Config, seeds []string) {
	// TargetUrl was already validated by the config check,
	// so there is no need to check it again;
	// just ignore the possible error
	targetUrlPattern, _ := regexp.Compile(config.TargetUrl)
	taskCommonCfg := &TaskCommonConfig{
		CrawlTimeout:     config.CrawlTimeout,
		OutputDirectory:  config.OutputDirectory,
		TargetUrlPattern: targetUrlPattern,
	}

	// initialize the task queue
	s.TaskQue.Init()
	for _, seed := range seeds {
		task := &Task{
			Url:       seed,
			Depth:     0,
			CommonCfg: taskCommonCfg,
		}
		s.TaskQue.Append(task)
	}

	// use a buffered channel to bound the number
	// of concurrently running goroutines
	s.TaskChan = make(chan struct{}, config.ThreadCount)

	s.MaxDepth = config.MaxDepth

	s.CrawlInterval = config.CrawlInterval

	s.ThreadCount = config.ThreadCount

	s.TaskCommonCfg = taskCommonCfg
}

// Start runs tasks until none are left.
func (s *Scheduler) Start() {
	log.Logger.Info("start to run tasks")

	for {
		if s.TaskQue.Len() == 0 && len(s.TaskChan) == 0 {
			// Appending new tasks to the task queue happens inside a task,
			// so len(s.TaskChan) == 0 guarantees that no new task will ever
			// be appended to s.TaskQue.
			log.Logger.Info("s.TaskQue has been drained")
			break
		}

		// s.TaskQue.Len() == 0 && len(s.TaskChan) != 0 means tasks are still
		// running but have not yet appended new tasks to the queue. Thanks to
		// the useless task, some task is guaranteed to arrive eventually, so
		// s.TaskQue.Remove() may wait for a while but will not block forever.

		// s.TaskQue.Len() != 0 && len(s.TaskChan) == 0 means no task is
		// running; the queue is non-empty, so just take a task and run it.

		// s.TaskQue.Len() != 0 && len(s.TaskChan) != 0 means tasks are still
		// running; the queue is non-empty, so just take a task and run it.
		task := s.TaskQue.Remove()
		s.RunTask(task.(*Task))
	}

	close(s.TaskChan)
	log.Logger.Info("all tasks done")
}
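// Illustrative walkthrough of the exit condition above (not part of the
// original source): one seed, ThreadCount = 1, and a seed page without
// sub-urls.
//
//	step                                             TaskQue.Len()  len(TaskChan)
//	Init appends the seed                                  1              0
//	the loop removes the seed; RunTask fills a slot        0              1
//	the loop sees Len() == 0 && len() != 0 and
//	waits inside Remove() for the next append
//	the task's defer appends a useless task and            1              0
//	then frees its slot
//	the loop removes the useless task; RunTask
//	returns at the depth check without touching            0              0
//	TaskChan
//	the loop sees Len() == 0 && len() == 0 and exits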
// RunTask runs a single task.
func (s *Scheduler) RunTask(task *Task) {
	if task.Depth >= s.MaxDepth {
		return
	}

	// Avoid crawling the same url twice.
	// LoadOrStore is a method of sync.Map: the first argument is the key and
	// the second is the value. If task.Url is already in UrlTable, ok is true;
	// otherwise ok is false and task.Url is stored in UrlTable.
	if _, ok := s.UrlTable.LoadOrStore(task.Url, true); ok {
		// the url is being crawled or has already been crawled
		return
	}

	s.TaskChan <- struct{}{}
	go func() {
		// The useless task drains TaskChan before the task queue becomes
		// empty, enabling a graceful exit; it returns as soon as it enters
		// RunTask and never adds an element to TaskChan.
		// Some crawl tasks yield no qualifying sub-urls (a page may have no
		// sub-urls at all, or none that match the regular expression).
		// Appending a uselessTask to the queue regardless guarantees that
		// the for loop in Start does not block when it encounters
		// s.TaskQue.Len() == 0 && len(s.TaskChan) != 0.
		uselessTask := NewUselessTask(s.MaxDepth)

		defer func() {
			log.Logger.Info("task %s done", task.Url)
			// append the useless task
			s.TaskQue.Append(uselessTask)
			<-s.TaskChan
		}()

		// enforce the crawl interval to avoid getting banned
		hostName, err := parser.ParseHostName(task.Url)
		if err != nil {
			log.Logger.Error("%s: parser.ParseHostName(): %s", task.Url, err.Error())
			return
		}
		timer, ok := s.TimerTable.LoadOrStore(hostName, time.NewTimer(time.Duration(s.CrawlInterval)*time.Second))
		if ok {
			// wait out the remaining interval, then re-arm the timer
			<-timer.(*time.Timer).C
			timer.(*time.Timer).Reset(time.Duration(s.CrawlInterval) * time.Second)
		}

		log.Logger.Info("start to crawl %s", task.Url)
		urlList, err := task.Run()
		if err != nil {
			log.Logger.Error("%s", err.Error())
			return
		}

		// generate new tasks
		for _, url := range urlList {
			nextTask := &Task{
				Url:       url,
				Depth:     task.Depth + 1,
				CommonCfg: s.TaskCommonCfg,
			}
			s.TaskQue.Append(nextTask)
		}
	}()
}
--------------------------------------------------------------------------------
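The per-host rate limiting in RunTask above boils down to one timer per hostname kept in a sync.Map: the first crawl of a host proceeds immediately, and every later crawl waits out the stored timer and re-arms it. A minimal standalone sketch of the same pattern (function and variable names are illustrative):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

var timerTable sync.Map

// waitForHost blocks until at least interval has passed since the previous
// call for the same host. As in the original code, the NewTimer argument is
// evaluated even when the host is already present; that spare timer is
// simply discarded.
func waitForHost(host string, interval time.Duration) {
	t, loaded := timerTable.LoadOrStore(host, time.NewTimer(interval))
	if loaded {
		<-t.(*time.Timer).C             // wait out the remaining interval
		t.(*time.Timer).Reset(interval) // re-arm for the next caller
	}
}

func main() {
	start := time.Now()
	for i := 0; i < 3; i++ {
		waitForHost("www.baidu.com", time.Second)
		fmt.Println("crawl", i, "after", time.Since(start).Round(time.Millisecond))
	}
	// prints roughly: crawl 0 after 0s, crawl 1 after 1s, crawl 2 after 2s
}
```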