-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathcert_checker.go
151 lines (127 loc) · 3.69 KB
/
cert_checker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"net"
"strings"
"time"
"gosrc.io/xmpp/stanza"
)
// TODO: Should this become an extension of the client? The code should
// probably be more modular, while keeping concerns separated to keep it simple.

// ServerCheck holds the parameters used to check the TLS certificate of an
// XMPP server. Build one with NewChecker and run the check with Check.
type ServerCheck struct {
	// address is the "host:port" endpoint that Check dials over TCP.
	address string
	// domain is the XMPP domain sent in the stream opening and used as the
	// expected TLS server name.
	domain string
}
// NewChecker builds a ServerCheck for the given server address.
//
// address is passed through extractParams, which yields the address to dial
// and the host part. domain is the XMPP domain to check; when it is empty the
// host part of address is used instead.
//
// If address cannot be parsed, the error is returned together with a checker
// whose domain is left unset.
func NewChecker(address, domain string) (*ServerCheck, error) {
	dialAddr, host, err := extractParams(address)
	checker := &ServerCheck{address: dialAddr}
	if err != nil {
		return checker, err
	}

	// An explicit domain wins; otherwise fall back to the host from address.
	checker.domain = domain
	if checker.domain == "" {
		checker.domain = host
	}
	return checker, nil
}
// Check triggers the actual TCP connection, based on the previously defined
// parameters, and validates the server certificate over STARTTLS.
//
// It opens an XMPP stream for c.domain, reads the stream features, requests
// STARTTLS, performs the TLS handshake, verifies that the certificate matches
// c.domain and finally runs checkExpiration on the connection.
//
// It returns nil only when all of those steps succeed. It returns an error
// when the connection or any protocol step fails, when the server answers
// with a stream error or an unexpected packet, when the server does not
// offer STARTTLS, or when a certificate is invalid or about to expire.
// The TCP connection is closed before Check returns.
func (c *ServerCheck) Check() error {
	const timeout = 15 * time.Second

	tcpconn, err := net.DialTimeout("tcp", c.address, timeout)
	if err != nil {
		return err
	}
	// The connection only lives for the duration of the check. The close error
	// is deliberately ignored: the verdict is already decided at that point.
	defer tcpconn.Close()

	decoder := xml.NewDecoder(tcpconn)

	// Send stream open tag
	if _, err = fmt.Fprintf(tcpconn, clientStreamOpen, c.domain); err != nil {
		return err
	}

	// Set xml decoder and extract streamID from reply (not used for now)
	if _, err = stanza.InitStream(decoder); err != nil {
		return err
	}

	// extract stream features
	packet, err := stanza.NextPacket(decoder)
	if err != nil {
		return fmt.Errorf("stream open decode features: %w", err)
	}

	var f stanza.StreamFeatures
	switch p := packet.(type) {
	case stanza.StreamFeatures:
		f = p
	case stanza.StreamError:
		return errors.New("open stream error: " + p.Error.Local)
	default:
		// NOTE(review): the message probably meant "unexpected packet"; it is
		// kept unchanged in case callers or tests match on the text.
		return errors.New("expected packet received while expecting features, got " + p.Name())
	}

	// Without STARTTLS there is no certificate to inspect.
	if _, ok := f.DoesStartTLS(); !ok {
		return errors.New("TLS not supported on server")
	}

	if _, err = fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"); err != nil {
		return err
	}

	var k stanza.TLSProceed
	if err = decoder.DecodeElement(&k, nil); err != nil {
		return fmt.Errorf("expecting starttls proceed: %w", err)
	}

	// We convert existing connection to TLS
	tlsConn := tls.Client(tcpconn, &tls.Config{ServerName: c.domain})
	if err = tlsConn.Handshake(); err != nil {
		return err
	}

	// We check that cert matches hostname
	if err = tlsConn.VerifyHostname(c.domain); err != nil {
		return err
	}

	return checkExpiration(tlsConn)
}
// checkExpiration inspects every certificate in the verified chains of
// tlsConn and reports the first one that is close to expiring.
//
// The remaining lifetime is truncated to whole hours before being compared,
// so a certificate is reported when that truncated value is 48 or less.
// Certificates sharing the same signature are only examined once. It returns
// nil when no certificate is reported, which includes the case where the
// connection has no verified chains.
func checkExpiration(tlsConn *tls.Conn) error {
	const minRemainingHours = 48

	seen := map[string]bool{}
	state := tlsConn.ConnectionState()
	for _, chain := range state.VerifiedChains {
		for _, cert := range chain {
			// The same certificate can appear in several chains.
			sig := string(cert.Signature)
			if seen[sig] {
				continue
			}
			seen[sig] = true

			remainingHours := int64(time.Until(cert.NotAfter).Hours())
			if remainingHours > minRemainingHours {
				continue
			}
			return fmt.Errorf("certificate '%s' will expire on %s", cert.Subject.CommonName, cert.NotAfter)
		}
	}
	return nil
}
// extractParams splits an XMPP server address into the address to dial and
// its host part.
//
// Accepted forms are "host", "host:" and "host:port", plus bracketed IPv6
// literals "[v6]", "[v6]:" and "[v6]:port". When the port is missing or
// empty, the default XMPP client port 5222 is appended.
//
// It returns the address to dial (always with a port on success), the host
// without port or brackets, and an error when the address contains more than
// one colon and is not a well-formed bracketed literal. On error the address
// is returned unchanged, along with whatever precedes its first colon.
func extractParams(addr string) (string, string, error) {
	const defaultPort = "5222"

	hostport := strings.Split(addr, ":")
	if len(hostport) > 2 {
		// More than one colon is only legal inside a bracketed IPv6 literal.
		if strings.HasPrefix(addr, "[") {
			host, port, splitErr := net.SplitHostPort(addr)
			if splitErr != nil {
				// "[v6]" has no port separator at all; retry with an empty port.
				host, port, splitErr = net.SplitHostPort(addr + ":")
			}
			if splitErr == nil {
				if port == "" {
					port = defaultPort
				}
				return net.JoinHostPort(host, port), host, nil
			}
		}
		return addr, hostport[0], errors.New("too many colons in xmpp server address")
	}

	// Address is composed of two parts, we are good
	if len(hostport) == 2 && hostport[1] != "" {
		return addr, hostport[0], nil
	}

	// Port was not passed, we append XMPP default port:
	return hostport[0] + ":" + defaultPort, hostport[0], nil
}