ns: add interface, use it, and fix thread-related namespace switch issues
authorDan Williams <dcbw@redhat.com>
Tue, 5 Apr 2016 16:10:31 +0000 (11:10 -0500)
committerDan Williams <dcbw@redhat.com>
Fri, 20 May 2016 22:10:25 +0000 (17:10 -0500)
Add a namespace object interface for somewhat cleaner code when
creating and switching between network namespaces.  All created
namespaces are now mounted in /var/run/netns to ensure they
have persistent inodes and paths that can be passed around
between plugin components without relying on the current namespace
being correct.

Also remove the thread-locking arguments from the ns package
per https://github.com/appc/cni/issues/183 by doing all the namespace
changes in a separate goroutine that locks/unlocks itself, instead of
the caller having to track OS thread locking.

19 files changed:
pkg/ip/link.go
pkg/ns/README.md [new file with mode: 0644]
pkg/ns/consts_linux_386.go [deleted file]
pkg/ns/consts_linux_amd64.go [deleted file]
pkg/ns/consts_linux_arm.go [deleted file]
pkg/ns/ns.go
pkg/ns/ns_test.go
pkg/testhelpers/testhelpers.go [deleted file]
pkg/testhelpers/testhelpers_suite_test.go [deleted file]
pkg/testhelpers/testhelpers_test.go [deleted file]
plugins/ipam/dhcp/lease.go
plugins/main/bridge/bridge.go
plugins/main/ipvlan/ipvlan.go
plugins/main/loopback/loopback.go
plugins/main/loopback/loopback_test.go
plugins/main/macvlan/macvlan.go
plugins/main/ptp/ptp.go
plugins/meta/tuning/tuning.go
test

index df16812..1b78567 100644 (file)
@@ -81,7 +81,7 @@ func RandomVethName() (string, error) {
 // SetupVeth sets up a virtual ethernet link.
 // Should be in container netns, and will switch back to hostNS to set the host
 // veth end up.
-func SetupVeth(contVethName string, mtu int, hostNS *os.File) (hostVeth, contVeth netlink.Link, err error) {
+func SetupVeth(contVethName string, mtu int, hostNS ns.NetNS) (hostVeth, contVeth netlink.Link, err error) {
        var hostVethName string
        hostVethName, contVeth, err = makeVeth(contVethName, mtu)
        if err != nil {
@@ -104,10 +104,10 @@ func SetupVeth(contVethName string, mtu int, hostNS *os.File) (hostVeth, contVet
                return
        }
 
-       err = ns.WithNetNS(hostNS, false, func(_ *os.File) error {
+       err = hostNS.Do(func(_ ns.NetNS) error {
                hostVeth, err := netlink.LinkByName(hostVethName)
                if err != nil {
-                       return fmt.Errorf("failed to lookup %q in %q: %v", hostVethName, hostNS.Name(), err)
+                       return fmt.Errorf("failed to lookup %q in %q: %v", hostVethName, hostNS.Path(), err)
                }
 
                if err = netlink.LinkSetUp(hostVeth); err != nil {
diff --git a/pkg/ns/README.md b/pkg/ns/README.md
new file mode 100644 (file)
index 0000000..e7b20c2
--- /dev/null
@@ -0,0 +1,31 @@
+### Namespaces, Threads, and Go
+On Linux each OS thread can have a different network namespace.  Go's thread scheduling model switches goroutines between OS threads based on OS thread load and whether the goroutine would block other goroutines.  This can result in a goroutine switching network namespaces without notice and lead to errors in your code.
+
+### Namespace Switching
+Switching namespaces with the `ns.Set()` method is not recommended without additional strategies to prevent unexpected namespace changes when your goroutines switch OS threads.
+
+Go provides the `runtime.LockOSThread()` function to ensure a specific goroutine executes on its current OS thread and prevents any other goroutine from running in that thread until the locked one exits.  Careful usage of `LockOSThread()` and goroutines can provide good control over which network namespace a given goroutine executes in.
+
+For example, you cannot rely on the `ns.Set()` namespace being the current namespace after the `Set()` call unless you do two things.  First, the goroutine calling `Set()` must have previously called `LockOSThread()`.  Second, you must ensure `runtime.UnlockOSThread()` is not called somewhere in-between.  You also cannot rely on the initial network namespace remaining the current network namespace if any other code in your program switches namespaces, unless you have already called `LockOSThread()` in that goroutine.  Note that `LockOSThread()` prevents the Go scheduler from optimally scheduling goroutines for best performance, so `LockOSThread()` should only be used in small, isolated goroutines that release the lock quickly.
+
+### Do() The Recommended Thing
+The `ns.Do()` method provides control over network namespaces for you by implementing these strategies. All code dependent on a particular network namespace should be wrapped in the `ns.Do()` method to ensure the correct namespace is selected for the duration of your code.  For example:
+
+```go
+targetNs, err := ns.NewNS()
+if err != nil {
+    return err
+}
+err = targetNs.Do(func(hostNs ns.NetNS) error {
+       dummy := &netlink.Dummy{
+               LinkAttrs: netlink.LinkAttrs{
+                       Name: "dummy0",
+               },
+       }
+       return netlink.LinkAdd(dummy)
+})
+```
+
+### Further Reading
+ - https://github.com/golang/go/wiki/LockOSThread
+ - http://morsmachine.dk/go-scheduler
diff --git a/pkg/ns/consts_linux_386.go b/pkg/ns/consts_linux_386.go
deleted file mode 100644 (file)
index fd6ed8a..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2015 CNI authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ns
-
-const setNsNr = 346
diff --git a/pkg/ns/consts_linux_amd64.go b/pkg/ns/consts_linux_amd64.go
deleted file mode 100644 (file)
index a86a68a..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2015 CNI authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ns
-
-const setNsNr = 308
diff --git a/pkg/ns/consts_linux_arm.go b/pkg/ns/consts_linux_arm.go
deleted file mode 100644 (file)
index 5beaaf3..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2015 CNI authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ns
-
-const setNsNr = 375
index 7253399..119a8ce 100644 (file)
 package ns
 
 import (
+       "crypto/rand"
        "fmt"
        "os"
+       "path"
        "runtime"
-       "syscall"
+       "sync"
+
+       "golang.org/x/sys/unix"
 )
 
-// SetNS sets the network namespace on a target file.
-func SetNS(f *os.File, flags uintptr) error {
-       _, _, err := syscall.RawSyscall(setNsNr, f.Fd(), flags, 0)
-       if err != 0 {
-               return err
-       }
+type NetNS interface {
+       // Executes the passed closure in this object's network namespace,
+       // attemtping to restore the original namespace before returning.
+       // However, since each OS thread can have a different network namespace,
+       // and Go's thread scheduling is highly variable, callers cannot
+       // guarantee any specific namespace is set unless operations that
+       // require that namespace are wrapped with Do().  Also, no code called
+       // from Do() should call runtime.UnlockOSThread(), or the risk
+       // of executing code in an incorrect namespace will be greater.  See
+       // https://github.com/golang/go/wiki/LockOSThread for further details.
+       Do(toRun func(NetNS) error) error
 
-       return nil
+       // Sets the current network namespace to this object's network namespace.
+       // Note that since Go's thread scheduling is highly variable, callers
+       // cannot guarantee the requested namespace will be the current namespace
+       // after this function is called; to ensure this wrap operations that
+       // require the namespace with Do() instead.
+       Set() error
+
+       // Returns the filesystem path representing this object's network namespace
+       Path() string
+
+       // Returns a file descriptor representing this object's network namespace
+       Fd() uintptr
+
+       // Cleans up this instance of the network namespace; if this instance
+       // is the last user the namespace will be destroyed
+       Close() error
 }
 
-// WithNetNSPath executes the passed closure under the given network
-// namespace, restoring the original namespace afterwards.
-// Changing namespaces must be done on a goroutine that has been
-// locked to an OS thread. If lockThread arg is true, this function
-// locks the goroutine prior to change namespace and unlocks before
-// returning
-func WithNetNSPath(nspath string, lockThread bool, f func(*os.File) error) error {
-       ns, err := os.Open(nspath)
+type netNS struct {
+       file    *os.File
+       mounted bool
+}
+
+func getCurrentThreadNetNSPath() string {
+       // /proc/self/ns/net returns the namespace of the main thread, not
+       // of whatever thread this goroutine is running on.  Make sure we
+       // use the thread's net namespace since the thread is switching around
+       return fmt.Sprintf("/proc/%d/task/%d/ns/net", os.Getpid(), unix.Gettid())
+}
+
+// Returns an object representing the current OS thread's network namespace
+func GetCurrentNS() (NetNS, error) {
+       return GetNS(getCurrentThreadNetNSPath())
+}
+
+// Returns an object representing the namespace referred to by @path
+func GetNS(nspath string) (NetNS, error) {
+       fd, err := os.Open(nspath)
        if err != nil {
-               return fmt.Errorf("Failed to open %v: %v", nspath, err)
+               return nil, err
        }
-       defer ns.Close()
-       return WithNetNS(ns, lockThread, f)
+       return &netNS{file: fd}, nil
 }
 
-// WithNetNS executes the passed closure under the given network
-// namespace, restoring the original namespace afterwards.
-// Changing namespaces must be done on a goroutine that has been
-// locked to an OS thread. If lockThread arg is true, this function
-// locks the goroutine prior to change namespace and unlocks before
-// returning.  If the closure returns an error, WithNetNS attempts to
-// restore the original namespace before returning.
-func WithNetNS(ns *os.File, lockThread bool, f func(*os.File) error) error {
-       if lockThread {
-               runtime.LockOSThread()
-               defer runtime.UnlockOSThread()
+// Creates a new persistent network namespace and returns an object
+// representing that namespace, without switching to it
+func NewNS() (NetNS, error) {
+       const nsRunDir = "/var/run/netns"
+
+       b := make([]byte, 16)
+       _, err := rand.Reader.Read(b)
+       if err != nil {
+               return nil, fmt.Errorf("failed to generate random netns name: %v", err)
+       }
+
+       err = os.MkdirAll(nsRunDir, 0755)
+       if err != nil {
+               return nil, err
+       }
+
+       // create an empty file at the mount point
+       nsName := fmt.Sprintf("cni-%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
+       nsPath := path.Join(nsRunDir, nsName)
+       mountPointFd, err := os.Create(nsPath)
+       if err != nil {
+               return nil, err
        }
-       // save a handle to current (host) network namespace
-       thisNS, err := os.Open("/proc/self/ns/net")
+       mountPointFd.Close()
+
+       // Ensure the mount point is cleaned up on errors; if the namespace
+       // was successfully mounted this will have no effect because the file
+       // is in-use
+       defer os.RemoveAll(nsPath)
+
+       var wg sync.WaitGroup
+       wg.Add(1)
+
+       // do namespace work in a dedicated goroutine, so that we can safely
+       // Lock/Unlock OSThread without upsetting the lock/unlock state of
+       // the caller of this function
+       var fd *os.File
+       go (func() {
+               defer wg.Done()
+               runtime.LockOSThread()
+
+               var origNS NetNS
+               origNS, err = GetNS(getCurrentThreadNetNSPath())
+               if err != nil {
+                       return
+               }
+               defer origNS.Close()
+
+               // create a new netns on the current thread
+               err = unix.Unshare(unix.CLONE_NEWNET)
+               if err != nil {
+                       return
+               }
+               defer origNS.Set()
+
+               // bind mount the new netns from the current thread onto the mount point
+               err = unix.Mount(getCurrentThreadNetNSPath(), nsPath, "none", unix.MS_BIND, "")
+               if err != nil {
+                       return
+               }
+
+               fd, err = os.Open(nsPath)
+               if err != nil {
+                       return
+               }
+       })()
+       wg.Wait()
+
        if err != nil {
-               return fmt.Errorf("Failed to open /proc/self/ns/net: %v", err)
+               unix.Unmount(nsPath, unix.MNT_DETACH)
+               return nil, fmt.Errorf("failed to create namespace: %v", err)
+       }
+
+       return &netNS{file: fd, mounted: true}, nil
+}
+
+func (ns *netNS) Path() string {
+       return ns.file.Name()
+}
+
+func (ns *netNS) Fd() uintptr {
+       return ns.file.Fd()
+}
+
+func (ns *netNS) Close() error {
+       ns.file.Close()
+
+       if ns.mounted {
+               if err := unix.Unmount(ns.file.Name(), unix.MNT_DETACH); err != nil {
+                       return fmt.Errorf("Failed to unmount namespace %s: %v", ns.file.Name(), err)
+               }
+               if err := os.RemoveAll(ns.file.Name()); err != nil {
+                       return fmt.Errorf("Failed to clean up namespace %s: %v", ns.file.Name(), err)
+               }
+       }
+       return nil
+}
+
+func (ns *netNS) Do(toRun func(NetNS) error) error {
+       containedCall := func(hostNS NetNS) error {
+               threadNS, err := GetNS(getCurrentThreadNetNSPath())
+               if err != nil {
+                       return fmt.Errorf("failed to open current netns: %v", err)
+               }
+               defer threadNS.Close()
+
+               // switch to target namespace
+               if err = ns.Set(); err != nil {
+                       return fmt.Errorf("error switching to ns %v: %v", ns.file.Name(), err)
+               }
+               defer threadNS.Set() // switch back
+
+               return toRun(hostNS)
        }
-       defer thisNS.Close()
 
-       if err = SetNS(ns, syscall.CLONE_NEWNET); err != nil {
-               return fmt.Errorf("Error switching to ns %v: %v", ns.Name(), err)
+       // save a handle to current network namespace
+       hostNS, err := GetNS(getCurrentThreadNetNSPath())
+       if err != nil {
+               return fmt.Errorf("Failed to open current namespace: %v", err)
        }
-       defer SetNS(thisNS, syscall.CLONE_NEWNET) // switch back
+       defer hostNS.Close()
+
+       var wg sync.WaitGroup
+       wg.Add(1)
+
+       var innerError error
+       go func() {
+               defer wg.Done()
+               runtime.LockOSThread()
+               innerError = containedCall(hostNS)
+       }()
+       wg.Wait()
+
+       return innerError
+}
 
-       if err = f(thisNS); err != nil {
-               return err
+func (ns *netNS) Set() error {
+       if _, _, err := unix.Syscall(unix.SYS_SETNS, ns.Fd(), uintptr(unix.CLONE_NEWNET), 0); err != 0 {
+               return fmt.Errorf("Error switching to ns %v: %v", ns.file.Name(), err)
        }
 
        return nil
 }
+
+// WithNetNSPath executes the passed closure under the given network
+// namespace, restoring the original namespace afterwards.
+func WithNetNSPath(nspath string, toRun func(NetNS) error) error {
+       ns, err := GetNS(nspath)
+       if err != nil {
+               return fmt.Errorf("Failed to open %v: %v", nspath, err)
+       }
+       defer ns.Close()
+       return ns.Do(toRun)
+}
index a901eb3..836025e 100644 (file)
@@ -17,109 +17,123 @@ package ns_test
 import (
        "errors"
        "fmt"
-       "math/rand"
        "os"
-       "os/exec"
        "path/filepath"
 
        "github.com/containernetworking/cni/pkg/ns"
-       "github.com/containernetworking/cni/pkg/testhelpers"
        . "github.com/onsi/ginkgo"
        . "github.com/onsi/gomega"
+       "golang.org/x/sys/unix"
 )
 
+func getInodeCurNetNS() (uint64, error) {
+       curNS, err := ns.GetCurrentNS()
+       if err != nil {
+               return 0, err
+       }
+       defer curNS.Close()
+       return getInodeNS(curNS)
+}
+
+func getInodeNS(netns ns.NetNS) (uint64, error) {
+       return getInodeFd(int(netns.Fd()))
+}
+
+func getInode(path string) (uint64, error) {
+       file, err := os.Open(path)
+       if err != nil {
+               return 0, err
+       }
+       defer file.Close()
+       return getInodeFd(int(file.Fd()))
+}
+
+func getInodeFd(fd int) (uint64, error) {
+       stat := &unix.Stat_t{}
+       err := unix.Fstat(fd, stat)
+       return stat.Ino, err
+}
+
 var _ = Describe("Linux namespace operations", func() {
        Describe("WithNetNS", func() {
                var (
-                       targetNetNSName string
-                       targetNetNSPath string
-                       targetNetNS     *os.File
+                       originalNetNS ns.NetNS
+                       targetNetNS   ns.NetNS
                )
 
                BeforeEach(func() {
                        var err error
 
-                       targetNetNSName = fmt.Sprintf("test-netns-%d", rand.Int())
-
-                       err = exec.Command("ip", "netns", "add", targetNetNSName).Run()
+                       originalNetNS, err = ns.NewNS()
                        Expect(err).NotTo(HaveOccurred())
 
-                       targetNetNSPath = filepath.Join("/var/run/netns/", targetNetNSName)
-                       targetNetNS, err = os.Open(targetNetNSPath)
+                       targetNetNS, err = ns.NewNS()
                        Expect(err).NotTo(HaveOccurred())
                })
 
                AfterEach(func() {
                        Expect(targetNetNS.Close()).To(Succeed())
-
-                       err := exec.Command("ip", "netns", "del", targetNetNSName).Run()
-                       Expect(err).NotTo(HaveOccurred())
+                       Expect(originalNetNS.Close()).To(Succeed())
                })
 
                It("executes the callback within the target network namespace", func() {
-                       expectedInode, err := testhelpers.GetInode(targetNetNSPath)
+                       expectedInode, err := getInodeNS(targetNetNS)
                        Expect(err).NotTo(HaveOccurred())
 
-                       var actualInode uint64
-                       var innerErr error
-                       err = ns.WithNetNS(targetNetNS, false, func(*os.File) error {
-                               actualInode, innerErr = testhelpers.GetInodeCurNetNS()
+                       err = targetNetNS.Do(func(ns.NetNS) error {
+                               defer GinkgoRecover()
+
+                               actualInode, err := getInodeCurNetNS()
+                               Expect(err).NotTo(HaveOccurred())
+                               Expect(actualInode).To(Equal(expectedInode))
                                return nil
                        })
                        Expect(err).NotTo(HaveOccurred())
-
-                       Expect(innerErr).NotTo(HaveOccurred())
-                       Expect(actualInode).To(Equal(expectedInode))
                })
 
                It("provides the original namespace as the argument to the callback", func() {
-                       hostNSInode, err := testhelpers.GetInodeCurNetNS()
-                       Expect(err).NotTo(HaveOccurred())
+                       // Ensure we start in originalNetNS
+                       err := originalNetNS.Do(func(ns.NetNS) error {
+                               defer GinkgoRecover()
 
-                       var inputNSInode uint64
-                       var innerErr error
-                       err = ns.WithNetNS(targetNetNS, false, func(inputNS *os.File) error {
-                               inputNSInode, err = testhelpers.GetInodeF(inputNS)
-                               return nil
-                       })
-                       Expect(err).NotTo(HaveOccurred())
-
-                       Expect(innerErr).NotTo(HaveOccurred())
-                       Expect(inputNSInode).To(Equal(hostNSInode))
-               })
+                               origNSInode, err := getInodeNS(originalNetNS)
+                               Expect(err).NotTo(HaveOccurred())
 
-               It("restores the calling thread to the original network namespace", func() {
-                       preTestInode, err := testhelpers.GetInodeCurNetNS()
-                       Expect(err).NotTo(HaveOccurred())
+                               err = targetNetNS.Do(func(hostNS ns.NetNS) error {
+                                       defer GinkgoRecover()
 
-                       err = ns.WithNetNS(targetNetNS, false, func(*os.File) error {
+                                       hostNSInode, err := getInodeNS(hostNS)
+                                       Expect(err).NotTo(HaveOccurred())
+                                       Expect(hostNSInode).To(Equal(origNSInode))
+                                       return nil
+                               })
                                return nil
                        })
                        Expect(err).NotTo(HaveOccurred())
-
-                       postTestInode, err := testhelpers.GetInodeCurNetNS()
-                       Expect(err).NotTo(HaveOccurred())
-
-                       Expect(postTestInode).To(Equal(preTestInode))
                })
 
                Context("when the callback returns an error", func() {
                        It("restores the calling thread to the original namespace before returning", func() {
-                               preTestInode, err := testhelpers.GetInodeCurNetNS()
-                               Expect(err).NotTo(HaveOccurred())
+                               err := originalNetNS.Do(func(ns.NetNS) error {
+                                       defer GinkgoRecover()
 
-                               _ = ns.WithNetNS(targetNetNS, false, func(*os.File) error {
-                                       return errors.New("potato")
-                               })
+                                       preTestInode, err := getInodeCurNetNS()
+                                       Expect(err).NotTo(HaveOccurred())
 
-                               postTestInode, err := testhelpers.GetInodeCurNetNS()
-                               Expect(err).NotTo(HaveOccurred())
+                                       _ = targetNetNS.Do(func(ns.NetNS) error {
+                                               return errors.New("potato")
+                                       })
 
-                               Expect(postTestInode).To(Equal(preTestInode))
+                                       postTestInode, err := getInodeCurNetNS()
+                                       Expect(err).NotTo(HaveOccurred())
+                                       Expect(postTestInode).To(Equal(preTestInode))
+                                       return nil
+                               })
+                               Expect(err).NotTo(HaveOccurred())
                        })
 
                        It("returns the error from the callback", func() {
-                               err := ns.WithNetNS(targetNetNS, false, func(*os.File) error {
+                               err := targetNetNS.Do(func(ns.NetNS) error {
                                        return errors.New("potato")
                                })
                                Expect(err).To(MatchError("potato"))
@@ -128,16 +142,40 @@ var _ = Describe("Linux namespace operations", func() {
 
                Describe("validating inode mapping to namespaces", func() {
                        It("checks that different namespaces have different inodes", func() {
-                               hostNSInode, err := testhelpers.GetInodeCurNetNS()
+                               origNSInode, err := getInodeNS(originalNetNS)
                                Expect(err).NotTo(HaveOccurred())
 
-                               testNsInode, err := testhelpers.GetInode(targetNetNSPath)
+                               testNsInode, err := getInodeNS(targetNetNS)
                                Expect(err).NotTo(HaveOccurred())
 
-                               Expect(hostNSInode).NotTo(Equal(0))
                                Expect(testNsInode).NotTo(Equal(0))
-                               Expect(testNsInode).NotTo(Equal(hostNSInode))
+                               Expect(testNsInode).NotTo(Equal(origNSInode))
+                       })
+
+                       It("should not leak a closed netns onto any threads in the process", func() {
+                               By("creating a new netns")
+                               createdNetNS, err := ns.NewNS()
+                               Expect(err).NotTo(HaveOccurred())
+
+                               By("discovering the inode of the created netns")
+                               createdNetNSInode, err := getInodeNS(createdNetNS)
+                               Expect(err).NotTo(HaveOccurred())
+                               createdNetNS.Close()
+
+                               By("comparing against the netns inode of every thread in the process")
+                               for _, netnsPath := range allNetNSInCurrentProcess() {
+                                       netnsInode, err := getInode(netnsPath)
+                                       Expect(err).NotTo(HaveOccurred())
+                                       Expect(netnsInode).NotTo(Equal(createdNetNSInode))
+                               }
                        })
                })
        })
 })
+
+func allNetNSInCurrentProcess() []string {
+       pid := unix.Getpid()
+       paths, err := filepath.Glob(fmt.Sprintf("/proc/%d/task/*/ns/net", pid))
+       Expect(err).NotTo(HaveOccurred())
+       return paths
+}
diff --git a/pkg/testhelpers/testhelpers.go b/pkg/testhelpers/testhelpers.go
deleted file mode 100644 (file)
index e9a0fb3..0000000
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright 2016 CNI authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package testhelpers provides common support behavior for tests
-package testhelpers
-
-import (
-       "fmt"
-       "os"
-       "runtime"
-       "sync"
-
-       "golang.org/x/sys/unix"
-
-       . "github.com/onsi/ginkgo"
-       . "github.com/onsi/gomega"
-)
-
-func getCurrentThreadNetNSPath() string {
-       pid := unix.Getpid()
-       tid := unix.Gettid()
-       return fmt.Sprintf("/proc/%d/task/%d/ns/net", pid, tid)
-}
-
-func GetInodeCurNetNS() (uint64, error) {
-       return GetInode(getCurrentThreadNetNSPath())
-}
-
-func GetInode(path string) (uint64, error) {
-       file, err := os.Open(path)
-       if err != nil {
-               return 0, err
-       }
-       defer file.Close()
-       return GetInodeF(file)
-}
-
-func GetInodeF(file *os.File) (uint64, error) {
-       stat := &unix.Stat_t{}
-       err := unix.Fstat(int(file.Fd()), stat)
-       return stat.Ino, err
-}
-
-/*
-A note about goroutines, Linux namespaces and runtime.LockOSThread
-
-In Linux, network namespaces have thread affinity.
-
-In the Go language runtime, goroutines do not have affinity for OS threads.
-The Go runtime scheduler moves goroutines around amongst OS threads.  It
-is supposed to be transparent to the Go programmer.
-
-In order to address cases where the programmer needs thread affinity, Go
-provides runtime.LockOSThread and runtime.UnlockOSThread()
-
-However, the Go runtime does not reference count the Lock and Unlock calls.
-Repeated calls to Lock will succeed, but the first call to Unlock will unlock
-everything.  Therefore, it is dangerous to hide a Lock/Unlock in a library
-function, such as in this package.
-
-The code below, in MakeNetworkNS, avoids this problem by spinning up a new
-Go routine specifically so that LockOSThread can be called on it.  Thus
-goroutine-thread affinity is maintained long enough to perform all the required
-namespace operations.
-
-Because the LockOSThread call is performed inside this short-lived goroutine,
-there is no effect either way on the caller's goroutine-thread affinity.
-
-* */
-
-func MakeNetworkNS(containerID string) string {
-       namespace := "/var/run/netns/" + containerID
-
-       err := os.MkdirAll("/var/run/netns", 0600)
-       Expect(err).NotTo(HaveOccurred())
-
-       // create an empty file at the mount point
-       mountPointFd, err := os.Create(namespace)
-       Expect(err).NotTo(HaveOccurred())
-       mountPointFd.Close()
-
-       var wg sync.WaitGroup
-       wg.Add(1)
-
-       // do namespace work in a dedicated goroutine, so that we can safely
-       // Lock/Unlock OSThread without upsetting the lock/unlock state of
-       // the caller of this function.  See block comment above.
-       go (func() {
-               defer wg.Done()
-
-               runtime.LockOSThread()
-               defer runtime.UnlockOSThread()
-
-               defer GinkgoRecover()
-
-               // capture current thread's original netns
-               currentThreadNetNSPath := getCurrentThreadNetNSPath()
-               originalNetNS, err := unix.Open(currentThreadNetNSPath, unix.O_RDONLY, 0)
-               Expect(err).NotTo(HaveOccurred())
-               defer unix.Close(originalNetNS)
-
-               // create a new netns on the current thread
-               err = unix.Unshare(unix.CLONE_NEWNET)
-               Expect(err).NotTo(HaveOccurred())
-
-               // bind mount the new netns from the current thread onto the mount point
-               err = unix.Mount(currentThreadNetNSPath, namespace, "none", unix.MS_BIND, "")
-               Expect(err).NotTo(HaveOccurred())
-
-               // reset current thread's netns to the original
-               _, _, e1 := unix.Syscall(unix.SYS_SETNS, uintptr(originalNetNS), uintptr(unix.CLONE_NEWNET), 0)
-               Expect(e1).To(BeZero())
-       })()
-
-       wg.Wait()
-
-       return namespace
-}
-
-func RemoveNetworkNS(networkNS string) error {
-       err := unix.Unmount(networkNS, unix.MNT_DETACH)
-
-       err = os.RemoveAll(networkNS)
-       return err
-}
diff --git a/pkg/testhelpers/testhelpers_suite_test.go b/pkg/testhelpers/testhelpers_suite_test.go
deleted file mode 100644 (file)
index 88bfc3d..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2016 CNI authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package testhelpers_test
-
-import (
-       "math/rand"
-
-       . "github.com/onsi/ginkgo"
-       "github.com/onsi/ginkgo/config"
-       . "github.com/onsi/gomega"
-
-       "testing"
-)
-
-func TestTesthelpers(t *testing.T) {
-       rand.Seed(config.GinkgoConfig.RandomSeed)
-       RegisterFailHandler(Fail)
-       RunSpecs(t, "Testhelpers Suite")
-}
diff --git a/pkg/testhelpers/testhelpers_test.go b/pkg/testhelpers/testhelpers_test.go
deleted file mode 100644 (file)
index 62d4585..0000000
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright 2016 CNI authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package testhelpers_test contains unit tests of the testhelpers
-//
-// Some of this stuff is non-trivial and can interact in surprising ways
-// with the Go runtime.  Better be safe.
-package testhelpers_test
-
-import (
-       "fmt"
-       "math/rand"
-       "path/filepath"
-
-       "golang.org/x/sys/unix"
-
-       "github.com/containernetworking/cni/pkg/testhelpers"
-       . "github.com/onsi/ginkgo"
-       . "github.com/onsi/gomega"
-)
-
-var _ = Describe("Test helper functions", func() {
-       Describe("MakeNetworkNS", func() {
-               It("should return the filepath to a network namespace", func() {
-                       containerID := fmt.Sprintf("c-%x", rand.Int31())
-                       nsPath := testhelpers.MakeNetworkNS(containerID)
-
-                       Expect(nsPath).To(BeAnExistingFile())
-
-                       testhelpers.RemoveNetworkNS(containerID)
-               })
-
-               It("should return a network namespace different from that of the caller", func() {
-                       containerID := fmt.Sprintf("c-%x", rand.Int31())
-
-                       By("discovering the inode of the current netns")
-                       originalNetNSPath := currentNetNSPath()
-                       originalNetNSInode, err := testhelpers.GetInode(originalNetNSPath)
-                       Expect(err).NotTo(HaveOccurred())
-
-                       By("creating a new netns")
-                       createdNetNSPath := testhelpers.MakeNetworkNS(containerID)
-                       defer testhelpers.RemoveNetworkNS(createdNetNSPath)
-
-                       By("discovering the inode of the created netns")
-                       createdNetNSInode, err := testhelpers.GetInode(createdNetNSPath)
-                       Expect(err).NotTo(HaveOccurred())
-
-                       By("comparing the inodes")
-                       Expect(createdNetNSInode).NotTo(Equal(originalNetNSInode))
-               })
-
-               It("should not leak the new netns onto any threads in the process", func() {
-                       containerID := fmt.Sprintf("c-%x", rand.Int31())
-
-                       By("creating a new netns")
-                       createdNetNSPath := testhelpers.MakeNetworkNS(containerID)
-                       defer testhelpers.RemoveNetworkNS(createdNetNSPath)
-
-                       By("discovering the inode of the created netns")
-                       createdNetNSInode, err := testhelpers.GetInode(createdNetNSPath)
-                       Expect(err).NotTo(HaveOccurred())
-
-                       By("comparing against the netns inode of every thread in the process")
-                       for _, netnsPath := range allNetNSInCurrentProcess() {
-                               netnsInode, err := testhelpers.GetInode(netnsPath)
-                               Expect(err).NotTo(HaveOccurred())
-                               Expect(netnsInode).NotTo(Equal(createdNetNSInode))
-                       }
-               })
-       })
-})
-
-func currentNetNSPath() string {
-       pid := unix.Getpid()
-       tid := unix.Gettid()
-       return fmt.Sprintf("/proc/%d/task/%d/ns/net", pid, tid)
-}
-
-func allNetNSInCurrentProcess() []string {
-       pid := unix.Getpid()
-       paths, err := filepath.Glob(fmt.Sprintf("/proc/%d/task/*/ns/net", pid))
-       Expect(err).NotTo(HaveOccurred())
-       return paths
-}
index 82e2fc6..eb2b403 100644 (file)
@@ -19,7 +19,6 @@ import (
        "log"
        "math/rand"
        "net"
-       "os"
        "sync"
        "time"
 
@@ -74,7 +73,7 @@ func AcquireLease(clientID, netns, ifName string) (*DHCPLease, error) {
 
        l.wg.Add(1)
        go func() {
-               errCh <- ns.WithNetNSPath(netns, true, func(_ *os.File) error {
+               errCh <- ns.WithNetNSPath(netns, func(_ ns.NetNS) error {
                        defer l.wg.Done()
 
                        link, err := netlink.LinkByName(ifName)
index dbf0bbb..b866c07 100644 (file)
@@ -19,7 +19,6 @@ import (
        "errors"
        "fmt"
        "net"
-       "os"
        "runtime"
        "syscall"
 
@@ -129,10 +128,10 @@ func ensureBridge(brName string, mtu int) (*netlink.Bridge, error) {
        return br, nil
 }
 
-func setupVeth(netns string, br *netlink.Bridge, ifName string, mtu int, hairpinMode bool) error {
+func setupVeth(netns ns.NetNS, br *netlink.Bridge, ifName string, mtu int, hairpinMode bool) error {
        var hostVethName string
 
-       err := ns.WithNetNSPath(netns, false, func(hostNS *os.File) error {
+       err := netns.Do(func(hostNS ns.NetNS) error {
                // create the veth pair in the container and move host end into host netns
                hostVeth, _, err := ip.SetupVeth(ifName, mtu, hostNS)
                if err != nil {
@@ -191,7 +190,13 @@ func cmdAdd(args *skel.CmdArgs) error {
                return err
        }
 
-       if err = setupVeth(args.Netns, br, args.IfName, n.MTU, n.HairpinMode); err != nil {
+       netns, err := ns.GetNS(args.Netns)
+       if err != nil {
+               return fmt.Errorf("failed to open netns %q: %v", args.Netns, err)
+       }
+       defer netns.Close()
+
+       if err = setupVeth(netns, br, args.IfName, n.MTU, n.HairpinMode); err != nil {
                return err
        }
 
@@ -209,7 +214,7 @@ func cmdAdd(args *skel.CmdArgs) error {
                result.IP4.Gateway = calcGatewayIP(&result.IP4.IP)
        }
 
-       err = ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       err = netns.Do(func(_ ns.NetNS) error {
                return ipam.ConfigureIface(args.IfName, result)
        })
        if err != nil {
@@ -254,7 +259,7 @@ func cmdDel(args *skel.CmdArgs) error {
        }
 
        var ipn *net.IPNet
-       err = ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
                var err error
                ipn, err = ip.DelLinkByNameAddr(args.IfName, netlink.FAMILY_V4)
                return err
index 6de0cb9..84f9c77 100644 (file)
@@ -18,7 +18,6 @@ import (
        "encoding/json"
        "errors"
        "fmt"
-       "os"
        "runtime"
 
        "github.com/containernetworking/cni/pkg/ip"
@@ -65,7 +64,7 @@ func modeFromString(s string) (netlink.IPVlanMode, error) {
        }
 }
 
-func createIpvlan(conf *NetConf, ifName string, netns *os.File) error {
+func createIpvlan(conf *NetConf, ifName string, netns ns.NetNS) error {
        mode, err := modeFromString(conf.Mode)
        if err != nil {
                return err
@@ -97,7 +96,7 @@ func createIpvlan(conf *NetConf, ifName string, netns *os.File) error {
                return fmt.Errorf("failed to create ipvlan: %v", err)
        }
 
-       return ns.WithNetNS(netns, false, func(_ *os.File) error {
+       return netns.Do(func(_ ns.NetNS) error {
                err := renameLink(tmpName, ifName)
                if err != nil {
                        return fmt.Errorf("failed to rename ipvlan to %q: %v", ifName, err)
@@ -112,9 +111,9 @@ func cmdAdd(args *skel.CmdArgs) error {
                return err
        }
 
-       netns, err := os.Open(args.Netns)
+       netns, err := ns.GetNS(args.Netns)
        if err != nil {
-               return fmt.Errorf("failed to open netns %q: %v", netns, err)
+               return fmt.Errorf("failed to open netns %q: %v", args.Netns, err)
        }
        defer netns.Close()
 
@@ -131,7 +130,7 @@ func cmdAdd(args *skel.CmdArgs) error {
                return errors.New("IPAM plugin returned missing IPv4 config")
        }
 
-       err = ns.WithNetNS(netns, false, func(_ *os.File) error {
+       err = netns.Do(func(_ ns.NetNS) error {
                return ipam.ConfigureIface(args.IfName, result)
        })
        if err != nil {
@@ -153,7 +152,7 @@ func cmdDel(args *skel.CmdArgs) error {
                return err
        }
 
-       return ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       return ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
                return ip.DelLinkByName(args.IfName)
        })
 }
index 1e5095d..186fd54 100644 (file)
@@ -15,8 +15,6 @@
 package main
 
 import (
-       "os"
-
        "github.com/containernetworking/cni/pkg/ns"
        "github.com/containernetworking/cni/pkg/skel"
        "github.com/containernetworking/cni/pkg/types"
@@ -25,7 +23,7 @@ import (
 
 func cmdAdd(args *skel.CmdArgs) error {
        args.IfName = "lo" // ignore config, this only works for loopback
-       err := ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       err := ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
                link, err := netlink.LinkByName(args.IfName)
                if err != nil {
                        return err // not tested
@@ -48,7 +46,7 @@ func cmdAdd(args *skel.CmdArgs) error {
 
 func cmdDel(args *skel.CmdArgs) error {
        args.IfName = "lo" // ignore config, this only works for loopback
-       err := ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       err := ns.WithNetNSPath(args.Netns, func(ns.NetNS) error {
                link, err := netlink.LinkByName(args.IfName)
                if err != nil {
                        return err // not tested
index 7c009d1..65b11e0 100644 (file)
@@ -17,12 +17,10 @@ package main_test
 import (
        "fmt"
        "net"
-       "os"
        "os/exec"
        "strings"
 
        "github.com/containernetworking/cni/pkg/ns"
-       "github.com/containernetworking/cni/pkg/testhelpers"
        . "github.com/onsi/ginkgo"
        . "github.com/onsi/gomega"
        "github.com/onsi/gomega/gbytes"
@@ -31,7 +29,7 @@ import (
 
 var _ = Describe("Loopback", func() {
        var (
-               networkNS   string
+               networkNS   ns.NetNS
                containerID string
                command     *exec.Cmd
                environ     []string
@@ -39,12 +37,14 @@ var _ = Describe("Loopback", func() {
 
        BeforeEach(func() {
                command = exec.Command(pathToLoPlugin)
-               containerID = "some-container-id"
-               networkNS = testhelpers.MakeNetworkNS(containerID)
+
+               var err error
+               networkNS, err = ns.NewNS()
+               Expect(err).NotTo(HaveOccurred())
 
                environ = []string{
                        fmt.Sprintf("CNI_CONTAINERID=%s", containerID),
-                       fmt.Sprintf("CNI_NETNS=%s", networkNS),
+                       fmt.Sprintf("CNI_NETNS=%s", networkNS.Path()),
                        fmt.Sprintf("CNI_IFNAME=%s", "this is ignored"),
                        fmt.Sprintf("CNI_ARGS=%s", "none"),
                        fmt.Sprintf("CNI_PATH=%s", "/some/test/path"),
@@ -53,7 +53,7 @@ var _ = Describe("Loopback", func() {
        })
 
        AfterEach(func() {
-               Expect(testhelpers.RemoveNetworkNS(networkNS)).To(Succeed())
+               Expect(networkNS.Close()).To(Succeed())
        })
 
        Context("when given a network namespace", func() {
@@ -67,7 +67,7 @@ var _ = Describe("Loopback", func() {
                        Eventually(session).Should(gexec.Exit(0))
 
                        var lo *net.Interface
-                       err = ns.WithNetNSPath(networkNS, true, func(hostNS *os.File) error {
+                       err = networkNS.Do(func(ns.NetNS) error {
                                var err error
                                lo, err = net.InterfaceByName("lo")
                                return err
@@ -87,7 +87,7 @@ var _ = Describe("Loopback", func() {
                        Eventually(session).Should(gexec.Exit(0))
 
                        var lo *net.Interface
-                       err = ns.WithNetNSPath(networkNS, true, func(hostNS *os.File) error {
+                       err = networkNS.Do(func(ns.NetNS) error {
                                var err error
                                lo, err = net.InterfaceByName("lo")
                                return err
index ce05871..f7eb656 100644 (file)
@@ -18,7 +18,6 @@ import (
        "encoding/json"
        "errors"
        "fmt"
-       "os"
        "runtime"
 
        "github.com/containernetworking/cni/pkg/ip"
@@ -74,7 +73,7 @@ func modeFromString(s string) (netlink.MacvlanMode, error) {
        }
 }
 
-func createMacvlan(conf *NetConf, ifName string, netns *os.File) error {
+func createMacvlan(conf *NetConf, ifName string, netns ns.NetNS) error {
        mode, err := modeFromString(conf.Mode)
        if err != nil {
                return err
@@ -106,7 +105,7 @@ func createMacvlan(conf *NetConf, ifName string, netns *os.File) error {
                return fmt.Errorf("failed to create macvlan: %v", err)
        }
 
-       return ns.WithNetNS(netns, false, func(_ *os.File) error {
+       return netns.Do(func(_ ns.NetNS) error {
                // TODO: duplicate following lines for ipv6 support, when it will be added in other places
                ipv4SysctlValueName := fmt.Sprintf(IPv4InterfaceArpProxySysctlTemplate, tmpName)
                if _, err := sysctl.Sysctl(ipv4SysctlValueName, "1"); err != nil {
@@ -130,7 +129,7 @@ func cmdAdd(args *skel.CmdArgs) error {
                return err
        }
 
-       netns, err := os.Open(args.Netns)
+       netns, err := ns.GetNS(args.Netns)
        if err != nil {
                return fmt.Errorf("failed to open netns %q: %v", netns, err)
        }
@@ -149,7 +148,7 @@ func cmdAdd(args *skel.CmdArgs) error {
                return errors.New("IPAM plugin returned missing IPv4 config")
        }
 
-       err = ns.WithNetNS(netns, false, func(_ *os.File) error {
+       err = netns.Do(func(_ ns.NetNS) error {
                return ipam.ConfigureIface(args.IfName, result)
        })
        if err != nil {
@@ -171,7 +170,7 @@ func cmdDel(args *skel.CmdArgs) error {
                return err
        }
 
-       return ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       return ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
                return ip.DelLinkByName(args.IfName)
        })
 }
index f92c49e..ec6e23e 100644 (file)
@@ -58,7 +58,7 @@ func setupContainerVeth(netns, ifName string, mtu int, pr *types.Result) (string
        // In other words we force all traffic to ARP via the gateway except for GW itself.
 
        var hostVethName string
-       err := ns.WithNetNSPath(netns, false, func(hostNS *os.File) error {
+       err := ns.WithNetNSPath(netns, func(hostNS ns.NetNS) error {
                hostVeth, _, err := ip.SetupVeth(ifName, mtu, hostNS)
                if err != nil {
                        return err
@@ -200,7 +200,7 @@ func cmdDel(args *skel.CmdArgs) error {
        }
 
        var ipn *net.IPNet
-       err := ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       err := ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
                var err error
                ipn, err = ip.DelLinkByNameAddr(args.IfName, netlink.FAMILY_V4)
                return err
index f48d322..75ba852 100644 (file)
@@ -21,7 +21,6 @@ import (
        "encoding/json"
        "fmt"
        "io/ioutil"
-       "os"
        "path/filepath"
        "strings"
 
@@ -45,7 +44,7 @@ func cmdAdd(args *skel.CmdArgs) error {
        // The directory /proc/sys/net is per network namespace. Enter in the
        // network namespace before writing on it.
 
-       err := ns.WithNetNSPath(args.Netns, false, func(hostNS *os.File) error {
+       err := ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
                for key, value := range tuningConf.SysCtl {
                        fileName := filepath.Join("/proc/sys", strings.Replace(key, ".", "/", -1))
                        fileName = filepath.Clean(fileName)
diff --git a/test b/test
index c5171bd..cb5de65 100755 (executable)
--- a/test
+++ b/test
@@ -11,8 +11,8 @@ set -e
 
 source ./build
 
-TESTABLE="plugins/ipam/dhcp plugins/main/loopback pkg/invoke pkg/ns pkg/skel pkg/types pkg/utils pkg/testhelpers"
-FORMATTABLE="$TESTABLE libcni pkg/ip pkg/ns pkg/types pkg/ipam plugins/ipam/host-local plugins/main/bridge plugins/meta/flannel plugins/meta/tuning"
+TESTABLE="plugins/ipam/dhcp plugins/main/loopback pkg/invoke pkg/ns pkg/skel pkg/types pkg/utils"
+FORMATTABLE="$TESTABLE libcni pkg/ip pkg/ipam plugins/ipam/host-local plugins/main/bridge plugins/meta/flannel plugins/meta/tuning"
 
 # user has not provided PKG override
 if [ -z "$PKG" ]; then