pkg/ns: don't allow operations after Close()
authorStefan Junker <mail@stefanjunker.de>
Tue, 24 May 2016 18:27:18 +0000 (20:27 +0200)
committerStefan Junker <mail@stefanjunker.de>
Tue, 24 May 2016 18:52:00 +0000 (20:52 +0200)
ns/ns.go

index 119a8ce..837ab8b 100644 (file)
--- a/ns/ns.go
+++ b/ns/ns.go
@@ -58,6 +58,7 @@ type NetNS interface {
 type netNS struct {
        file    *os.File
        mounted bool
+       closed  bool
 }
 
 func getCurrentThreadNetNSPath() string {
@@ -165,8 +166,22 @@ func (ns *netNS) Fd() uintptr {
        return ns.file.Fd()
 }
 
+func (ns *netNS) errorIfClosed() error {
+       if ns.closed {
+               return fmt.Errorf("%q has already been closed", ns.file.Name())
+       }
+       return nil
+}
+
 func (ns *netNS) Close() error {
-       ns.file.Close()
+       if err := ns.errorIfClosed(); err != nil {
+               return err
+       }
+
+       if err := ns.file.Close(); err != nil {
+               return fmt.Errorf("Failed to close %q: %v", ns.file.Name(), err)
+       }
+       ns.closed = true
 
        if ns.mounted {
                if err := unix.Unmount(ns.file.Name(), unix.MNT_DETACH); err != nil {
@@ -175,11 +190,17 @@ func (ns *netNS) Close() error {
                if err := os.RemoveAll(ns.file.Name()); err != nil {
                        return fmt.Errorf("Failed to clean up namespace %s: %v", ns.file.Name(), err)
                }
+               ns.mounted = false
        }
+
        return nil
 }
 
 func (ns *netNS) Do(toRun func(NetNS) error) error {
+       if err := ns.errorIfClosed(); err != nil {
+               return err
+       }
+
        containedCall := func(hostNS NetNS) error {
                threadNS, err := GetNS(getCurrentThreadNetNSPath())
                if err != nil {
@@ -218,6 +239,10 @@ func (ns *netNS) Do(toRun func(NetNS) error) error {
 }
 
 func (ns *netNS) Set() error {
+       if err := ns.errorIfClosed(); err != nil {
+               return err
+       }
+
        if _, _, err := unix.Syscall(unix.SYS_SETNS, ns.Fd(), uintptr(unix.CLONE_NEWNET), 0); err != 0 {
                return fmt.Errorf("Error switching to ns %v: %v", ns.file.Name(), err)
        }