Remove useless wrappers
tls and vncclient close their underlying connection anyways
This commit is contained in:
parent
13c25063af
commit
991ced778f
@ -11,27 +11,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyWrapper struct {
|
||||
Client *ssh.Client
|
||||
|
||||
rawConnection net.Conn
|
||||
sshConnection ssh.Conn
|
||||
}
|
||||
|
||||
func (w ProxyWrapper) Close() {
|
||||
if w.Client != nil {
|
||||
w.Client.Close()
|
||||
}
|
||||
|
||||
if w.sshConnection != nil {
|
||||
w.sshConnection.Close()
|
||||
}
|
||||
|
||||
if w.rawConnection != nil {
|
||||
w.rawConnection.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func GetXenProxyAddress(state multistep.StateBag) (string, error) {
|
||||
proxyAddress, ok := state.GetOk("xen_proxy_address")
|
||||
|
||||
@ -66,7 +45,7 @@ func ConnectViaProxy(proxyAddress, address string) (net.Conn, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, password string) (*ProxyWrapper, error) {
|
||||
func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, password string) (*ssh.Client, error) {
|
||||
connection, err := ConnectViaProxy(proxyAddress, fmt.Sprintf("%s:%d", host, port))
|
||||
|
||||
if err != nil {
|
||||
@ -111,16 +90,16 @@ func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, p
|
||||
|
||||
sshClient := ssh.NewClient(sshConn, sshChan, req)
|
||||
|
||||
wrapper := ProxyWrapper{
|
||||
Client: sshClient,
|
||||
rawConnection: connection,
|
||||
sshConnection: sshConn,
|
||||
}
|
||||
|
||||
return &wrapper, nil
|
||||
return sshClient, nil
|
||||
}
|
||||
|
||||
func CreatePortForwarding(proxyAddress, targetAddress string) (net.Listener, error) {
|
||||
return CreateCustomPortForwarding(func() (net.Conn, error) {
|
||||
return ConnectViaProxy(proxyAddress, targetAddress)
|
||||
})
|
||||
}
|
||||
|
||||
func CreateCustomPortForwarding(connectTarget func() (net.Conn, error)) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
||||
if err != nil {
|
||||
@ -135,32 +114,22 @@ func CreatePortForwarding(proxyAddress, targetAddress string) (net.Listener, err
|
||||
continue
|
||||
}
|
||||
|
||||
go handleConnection(proxyAddress, targetAddress, accept)
|
||||
go handleConnection(accept, connectTarget)
|
||||
}
|
||||
}()
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func handleConnection(proxyAddress, targetAddress string, accept net.Conn) {
|
||||
defer accept.Close()
|
||||
conn, err := ConnectViaProxy(proxyAddress, targetAddress)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("[FORWARD] Connect proxy Error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
func serviceForwardedConnection(clientConn net.Conn, targetConn net.Conn) {
|
||||
txDone := make(chan struct{})
|
||||
rxDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(conn, accept)
|
||||
_, err := io.Copy(targetConn, clientConn)
|
||||
|
||||
// Close conn so that other copy operation unblocks
|
||||
conn.Close()
|
||||
targetConn.Close()
|
||||
close(txDone)
|
||||
|
||||
if err != nil {
|
||||
@ -170,10 +139,10 @@ func handleConnection(proxyAddress, targetAddress string, accept net.Conn) {
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(accept, conn)
|
||||
_, err := io.Copy(clientConn, targetConn)
|
||||
|
||||
// Close accept so that other copy operation unblocks
|
||||
accept.Close()
|
||||
clientConn.Close()
|
||||
close(rxDone)
|
||||
|
||||
if err != nil {
|
||||
@ -185,3 +154,17 @@ func handleConnection(proxyAddress, targetAddress string, accept net.Conn) {
|
||||
<-txDone
|
||||
<-rxDone
|
||||
}
|
||||
|
||||
func handleConnection(clientConn net.Conn, connectTarget func() (net.Conn, error)) {
|
||||
defer clientConn.Close()
|
||||
targetConn, err := connectTarget()
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("[FORWARD] Connect proxy Error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer targetConn.Close()
|
||||
|
||||
serviceForwardedConnection(clientConn, targetConn)
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ func ExecuteHostSSHCmd(state multistep.StateBag, cmd string) (stdout string, err
|
||||
|
||||
defer sshClient.Close()
|
||||
|
||||
return doExecuteSSHCmd(cmd, sshClient.Client)
|
||||
return doExecuteSSHCmd(cmd, sshClient)
|
||||
}
|
||||
|
||||
func connectSSH(host string, port int, username string, password string) (*ssh.Client, error) {
|
||||
|
@ -27,7 +27,6 @@ type StepTypeBootCommand struct {
|
||||
func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction {
|
||||
config := state.Get("commonconfig").(CommonConfig)
|
||||
ui := state.Get("ui").(packer.Ui)
|
||||
c := state.Get("client").(*Connection)
|
||||
httpPort := state.Get("http_port").(int)
|
||||
|
||||
var httpIP string
|
||||
@ -42,44 +41,25 @@ func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateB
|
||||
return multistep.ActionContinue
|
||||
}
|
||||
|
||||
vmRef, err := c.client.VM.GetByNameLabel(c.session, config.VMName)
|
||||
|
||||
location, err := GetVNCConsoleLocation(state)
|
||||
if err != nil {
|
||||
state.Put("error", err)
|
||||
ui.Error(err.Error())
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
if len(vmRef) != 1 {
|
||||
ui.Error(fmt.Sprintf("expected to find a single VM, instead found '%d'. Ensure the VM name is unique", len(vmRef)))
|
||||
}
|
||||
|
||||
consoles, err := c.client.VM.GetConsoles(c.session, vmRef[0])
|
||||
if err != nil {
|
||||
state.Put("error", err)
|
||||
ui.Error(err.Error())
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
if len(consoles) != 1 {
|
||||
ui.Error(fmt.Sprintf("expected to find a VM console, instead found '%d'. Ensure there is only one console", len(consoles)))
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
location, err := c.client.Console.GetLocation(c.session, consoles[0])
|
||||
|
||||
ui.Say(fmt.Sprintf("Connecting to the VM console VNC over xapi via %s", location))
|
||||
|
||||
vncConnectionWrapper, err := ConnectVNC(state, location)
|
||||
vncClient, err := CreateVNCClient(state, location)
|
||||
|
||||
if err != nil {
|
||||
ui.Error(err.Error())
|
||||
return multistep.ActionHalt
|
||||
}
|
||||
|
||||
defer vncConnectionWrapper.Close()
|
||||
defer vncClient.Close()
|
||||
|
||||
log.Printf("Connected to the VNC console: %s", vncConnectionWrapper.Client.DesktopName)
|
||||
log.Printf("Connected to the VNC console: %s", vncClient.DesktopName)
|
||||
|
||||
self.Ctx.Data = &bootCommandTemplateData{
|
||||
config.VMName,
|
||||
@ -87,7 +67,7 @@ func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateB
|
||||
uint(httpPort),
|
||||
}
|
||||
|
||||
vncDriver := bootcommand.NewVNCDriver(vncConnectionWrapper.Client, config.VNCConfig.BootKeyInterval)
|
||||
vncDriver := bootcommand.NewVNCDriver(vncClient, config.VNCConfig.BootKeyInterval)
|
||||
|
||||
ui.Say("Typing boot commands over VNC...")
|
||||
|
||||
|
@ -14,87 +14,86 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type VNCConnectionWrapper struct {
|
||||
TlsConn *tls.Conn
|
||||
rawConn net.Conn
|
||||
}
|
||||
|
||||
func (v VNCConnectionWrapper) Close() {
|
||||
if v.TlsConn != nil {
|
||||
_ = v.TlsConn.Close()
|
||||
}
|
||||
|
||||
if v.rawConn != nil {
|
||||
_ = v.rawConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type VNCClientWrapper struct {
|
||||
Client *vnc.ClientConn
|
||||
|
||||
connection *VNCConnectionWrapper
|
||||
}
|
||||
|
||||
func (v VNCClientWrapper) Close() {
|
||||
if v.Client != nil {
|
||||
_ = v.Client.Close()
|
||||
}
|
||||
|
||||
v.connection.Close()
|
||||
}
|
||||
|
||||
func CreateVNCConnection(state multistep.StateBag, location string) (*VNCConnectionWrapper, error) {
|
||||
func GetVNCConsoleLocation(state multistep.StateBag) (string, error) {
|
||||
xenClient := state.Get("client").(*Connection)
|
||||
wrapper := VNCConnectionWrapper{}
|
||||
config := state.Get("commonconfig").(CommonConfig)
|
||||
|
||||
var err error
|
||||
vmRef, err := xenClient.client.VM.GetByNameLabel(xenClient.session, config.VMName)
|
||||
|
||||
target, err := getTcpAddress(location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
wrapper.rawConn, err = ConnectViaXenProxy(state, target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(vmRef) != 1 {
|
||||
return "", fmt.Errorf("expected to find a single VM, instead found '%d'. Ensure the VM name is unique", len(vmRef))
|
||||
}
|
||||
|
||||
wrapper.TlsConn, err = httpConnectRequest(location, string(xenClient.GetSessionRef()), wrapper.rawConn)
|
||||
consoles, err := xenClient.client.VM.GetConsoles(xenClient.session, vmRef[0])
|
||||
|
||||
if err != nil {
|
||||
wrapper.Close()
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return &wrapper, nil
|
||||
if len(consoles) != 1 {
|
||||
return "", fmt.Errorf("expected to find a VM console, instead found '%d'. Ensure there is only one console", len(consoles))
|
||||
}
|
||||
|
||||
location, err := xenClient.client.Console.GetLocation(xenClient.session, consoles[0])
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return location, nil
|
||||
}
|
||||
|
||||
func ConnectVNC(state multistep.StateBag, location string) (*VNCClientWrapper, error) {
|
||||
wrapper := VNCClientWrapper{}
|
||||
func CreateVNCConnection(state multistep.StateBag, location string) (net.Conn, error) {
|
||||
xenClient := state.Get("client").(*Connection)
|
||||
|
||||
var err error
|
||||
|
||||
wrapper.connection, err = CreateVNCConnection(state, location)
|
||||
target, err := GetTcpAddressFromURL(location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrapper.Client, err = vnc.Client(wrapper.connection.TlsConn, &vnc.ClientConfig{
|
||||
rawConn, err := ConnectViaXenProxy(state, target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn, err := initializeVNCConnection(location, string(xenClient.GetSessionRef()), rawConn)
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func CreateVNCClient(state multistep.StateBag, location string) (*vnc.ClientConn, error) {
|
||||
var err error
|
||||
|
||||
connection, err := CreateVNCConnection(state, location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := vnc.Client(connection, &vnc.ClientConfig{
|
||||
Exclusive: false,
|
||||
})
|
||||
if err != nil {
|
||||
wrapper.Close()
|
||||
connection.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &wrapper, nil
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func httpConnectRequest(location string, xenSessionRef string, proxyConnection net.Conn) (*tls.Conn, error) {
|
||||
func initializeVNCConnection(location string, xenSessionRef string, rawConn net.Conn) (*tls.Conn, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
|
||||
tlsConnection := tls.Client(proxyConnection, tlsConfig)
|
||||
tlsConnection := tls.Client(rawConn, tlsConfig)
|
||||
|
||||
request, err := http.NewRequest(http.MethodConnect, location, http.NoBody)
|
||||
|
||||
@ -123,7 +122,7 @@ func httpConnectRequest(location string, xenSessionRef string, proxyConnection n
|
||||
return tlsConnection, nil
|
||||
}
|
||||
|
||||
func getTcpAddress(location string) (string, error) {
|
||||
func GetTcpAddressFromURL(location string) (string, error) {
|
||||
parsedUrl, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
Loading…
Reference in New Issue
Block a user